Skip to main content

yulang_native/
cps_optimize.rs

1//! Optimization entrypoint for backend-facing CPS representation ABI.
2//!
3//! This module is intentionally placed between CPS repr ABI lowering and
4//! Cranelift codegen so JIT and object generation share one transformation
5//! path.  Early passes are deliberately conservative: they rewrite explicit
6//! continuation call sites, but leave closure entries, thunk entries, and
7//! handler arm entries alone unless their call protocol is represented at the
8//! call site.
9
10use std::collections::{HashMap, HashSet, VecDeque};
11
12use crate::cps_ir::{
13    CpsContinuationId, CpsHandlerEnv, CpsRecordField, CpsShotKind, CpsStmt, CpsTerminator,
14    CpsValueId,
15};
16use crate::cps_repr_abi::{
17    CpsReprAbiContinuation, CpsReprAbiFunction, CpsReprAbiModule, CpsReprAbiValue,
18};
19use yulang_typed_ir as typed_ir;
20
21#[derive(Debug, Clone, PartialEq, Eq)]
22pub struct CpsOptimizationOutput {
23    pub module: CpsReprAbiModule,
24    pub profile: CpsOptimizationProfile,
25}
26
/// Size measurements and rewrite counters collected during one optimizer run.
///
/// The rewrite counters are cumulative over all simplification rounds: every
/// pass adds the number of rewrites it performed to its own counter.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct CpsOptimizationProfile {
    /// Number of functions (presumably measured on the input module).
    pub functions: usize,
    /// Number of roots (presumably measured on the input module).
    pub roots: usize,
    /// Continuation count before optimization (presumably filled by `measure`).
    pub continuations: usize,
    /// Handler count (presumably measured on the input module).
    pub handlers: usize,
    /// Statement count before optimization (presumably filled by `measure`).
    pub statements: usize,
    /// Continuation count after optimization (presumably written by
    /// `record_optimized_size`).
    pub optimized_continuations: usize,
    /// Statement count after optimization (presumably written by
    /// `record_optimized_size`).
    pub optimized_statements: usize,
    /// Number of pass invocations; each pass increments it on entry.
    pub passes_run: usize,
    /// Terminators rewritten by `rewrite_forwarding_continuation_calls`.
    pub forwarded_continuation_calls: usize,
    /// Terminators rewritten by `rewrite_returning_continuation_calls`.
    pub returned_continuation_calls: usize,
    /// Branches on literal booleans turned into plain `Continue`s.
    pub folded_constant_branches: usize,
    /// Effectful calls to pure callees turned into direct calls.
    pub rewritten_pure_effectful_calls: usize,
    /// Direct calls to primitive wrappers replaced by the primitive itself.
    pub reified_primitive_calls: usize,
    /// Closure applications of locally created wrapper closures reified.
    pub reified_partial_closure_calls: usize,
    /// Closure applications of wrapper closures known through parameters reified.
    pub reified_known_closure_parameter_calls: usize,
    /// Continuation parameters removed because their value was never read.
    pub removed_unused_continuation_params: usize,
    /// Rewrites reported by `fold_structural_projections`.
    pub folded_structural_projections: usize,
    /// Rewrites reported by `inline_pure_direct_calls`.
    pub inlined_pure_direct_calls: usize,
    /// Rewrites reported by `inline_single_use_continuation_calls`.
    pub inlined_continuation_calls: usize,
    /// Continuations dropped by `prune_unreachable_continuations`.
    pub removed_unreachable_continuations: usize,
    /// Statements dropped by `eliminate_dead_pure_statements`.
    pub removed_dead_pure_statements: usize,
    /// Direct-style islands counted after the simplification rounds.
    pub direct_style_islands: usize,
    /// Total continuations contained in those islands.
    pub direct_style_continuations: usize,
    /// Raised by the passes when their rewrite counters are non-zero.
    pub changed: bool,
}
54
55pub fn optimize_cps_repr_abi_module(module: &CpsReprAbiModule) -> CpsOptimizationOutput {
56    let mut output = CpsOptimizationOutput {
57        module: module.clone(),
58        profile: CpsOptimizationProfile::measure(module),
59    };
60
61    for _ in 0..4 {
62        if !run_simplification_round(&mut output) {
63            break;
64        }
65    }
66    output.profile.record_optimized_size(&output.module);
67    analyze_direct_style_islands(&mut output);
68    maybe_trace_profile(&output.profile);
69    output
70}
71
72fn run_simplification_round(output: &mut CpsOptimizationOutput) -> bool {
73    let before = output.profile;
74    rewrite_forwarding_continuation_calls(output);
75    rewrite_returning_continuation_calls(output);
76    fold_constant_branches(output);
77    rewrite_pure_effectful_calls(output);
78    reify_direct_primitive_calls(output);
79    reify_local_partial_closure_calls(output);
80    reify_known_closure_parameter_calls(output);
81    remove_unused_continuation_params(output);
82    fold_structural_projections(output);
83    inline_pure_direct_calls(output);
84    inline_single_use_continuation_calls(output);
85    reify_local_partial_closure_calls(output);
86    reify_known_closure_parameter_calls(output);
87    remove_unused_continuation_params(output);
88    prune_unreachable_continuations(output);
89    eliminate_dead_pure_statements(output);
90    prune_unreachable_continuations(output);
91    output.profile.has_more_changes_than(before)
92}
93
94fn rewrite_forwarding_continuation_calls(output: &mut CpsOptimizationOutput) {
95    output.profile.passes_run += 1;
96    for function in output
97        .module
98        .functions
99        .iter_mut()
100        .chain(&mut output.module.roots)
101    {
102        let forwarders = forwarding_continuations(function);
103        if forwarders.is_empty() {
104            continue;
105        }
106        for continuation in &mut function.continuations {
107            output.profile.forwarded_continuation_calls +=
108                rewrite_terminator_forwarders(&mut continuation.terminator, &forwarders);
109        }
110    }
111    output.profile.changed = output.profile.forwarded_continuation_calls > 0;
112}
113
114fn rewrite_returning_continuation_calls(output: &mut CpsOptimizationOutput) {
115    output.profile.passes_run += 1;
116    for function in output
117        .module
118        .functions
119        .iter_mut()
120        .chain(&mut output.module.roots)
121    {
122        let returners = returning_continuations(function);
123        if returners.is_empty() {
124            continue;
125        }
126        for continuation in &mut function.continuations {
127            output.profile.returned_continuation_calls +=
128                rewrite_terminator_returners(&mut continuation.terminator, &returners);
129        }
130    }
131    output.profile.changed |= output.profile.returned_continuation_calls > 0;
132}
133
134fn fold_constant_branches(output: &mut CpsOptimizationOutput) {
135    output.profile.passes_run += 1;
136    for function in output
137        .module
138        .functions
139        .iter_mut()
140        .chain(&mut output.module.roots)
141    {
142        let empty_param_continuations = function
143            .continuations
144            .iter()
145            .filter(|continuation| continuation.params.is_empty())
146            .map(|continuation| continuation.id)
147            .collect::<HashSet<_>>();
148        for continuation in &mut function.continuations {
149            output.profile.folded_constant_branches +=
150                fold_constant_branch_in_continuation(continuation, &empty_param_continuations);
151        }
152    }
153    output.profile.changed |= output.profile.folded_constant_branches > 0;
154}
155
156fn rewrite_pure_effectful_calls(output: &mut CpsOptimizationOutput) {
157    output.profile.passes_run += 1;
158    let pure_functions = pure_callable_functions(&output.module);
159    if pure_functions.is_empty() {
160        return;
161    }
162    for function in output
163        .module
164        .functions
165        .iter_mut()
166        .chain(&mut output.module.roots)
167    {
168        output.profile.rewritten_pure_effectful_calls +=
169            rewrite_pure_effectful_calls_in_function(function, &pure_functions);
170    }
171    output.profile.changed |= output.profile.rewritten_pure_effectful_calls > 0;
172}
173
174fn reify_direct_primitive_calls(output: &mut CpsOptimizationOutput) {
175    output.profile.passes_run += 1;
176    let primitives = primitive_wrappers(&output.module);
177    if primitives.is_empty() {
178        return;
179    }
180    for function in output
181        .module
182        .functions
183        .iter_mut()
184        .chain(&mut output.module.roots)
185    {
186        for continuation in &mut function.continuations {
187            for stmt in &mut continuation.stmts {
188                output.profile.reified_primitive_calls +=
189                    reify_direct_primitive_stmt(stmt, &primitives);
190            }
191        }
192    }
193    output.profile.changed |= output.profile.reified_primitive_calls > 0;
194}
195
196fn reify_local_partial_closure_calls(output: &mut CpsOptimizationOutput) {
197    output.profile.passes_run += 1;
198    for function in output
199        .module
200        .functions
201        .iter_mut()
202        .chain(&mut output.module.roots)
203    {
204        let wrappers = partial_closure_wrappers(function);
205        if wrappers.is_empty() {
206            continue;
207        }
208        let resumable = scalar_resume_continuations(function);
209        let mut next_value = next_function_value_id(function);
210        for continuation in &mut function.continuations {
211            output.profile.reified_partial_closure_calls +=
212                reify_local_partial_closure_calls_in_continuation(
213                    continuation,
214                    &wrappers,
215                    &resumable,
216                    &mut next_value,
217                );
218        }
219    }
220    output.profile.changed |= output.profile.reified_partial_closure_calls > 0;
221}
222
223fn reify_known_closure_parameter_calls(output: &mut CpsOptimizationOutput) {
224    output.profile.passes_run += 1;
225    for function in output
226        .module
227        .functions
228        .iter_mut()
229        .chain(&mut output.module.roots)
230    {
231        let wrappers = partial_closure_wrappers(function);
232        if wrappers.is_empty() {
233            continue;
234        }
235        output.profile.reified_known_closure_parameter_calls +=
236            reify_known_closure_parameter_calls_in_function(function, &wrappers);
237    }
238    output.profile.changed |= output.profile.reified_known_closure_parameter_calls > 0;
239}
240
241fn remove_unused_continuation_params(output: &mut CpsOptimizationOutput) {
242    output.profile.passes_run += 1;
243    for function in output
244        .module
245        .functions
246        .iter_mut()
247        .chain(&mut output.module.roots)
248    {
249        output.profile.removed_unused_continuation_params +=
250            remove_unused_continuation_params_in_function(function);
251    }
252    output.profile.changed |= output.profile.removed_unused_continuation_params > 0;
253}
254
255fn fold_structural_projections(output: &mut CpsOptimizationOutput) {
256    output.profile.passes_run += 1;
257    for function in output
258        .module
259        .functions
260        .iter_mut()
261        .chain(&mut output.module.roots)
262    {
263        for continuation in &mut function.continuations {
264            output.profile.folded_structural_projections +=
265                fold_structural_projections_in_continuation(continuation);
266        }
267    }
268    output.profile.changed |= output.profile.folded_structural_projections > 0;
269}
270
271fn inline_pure_direct_calls(output: &mut CpsOptimizationOutput) {
272    output.profile.passes_run += 1;
273    let candidates = pure_direct_inline_candidates(&output.module);
274    if candidates.is_empty() {
275        return;
276    }
277    for function in output
278        .module
279        .functions
280        .iter_mut()
281        .chain(&mut output.module.roots)
282    {
283        output.profile.inlined_pure_direct_calls +=
284            inline_pure_direct_calls_in_function(function, &candidates);
285    }
286    output.profile.changed |= output.profile.inlined_pure_direct_calls > 0;
287}
288
289fn remove_unused_continuation_params_in_function(function: &mut CpsReprAbiFunction) -> usize {
290    let unused_slots = unused_continuation_param_slots(function);
291    if unused_slots.is_empty() {
292        return 0;
293    }
294
295    let mut removed = 0;
296    for continuation in &mut function.continuations {
297        if let Some(slots) = unused_slots.get(&continuation.id) {
298            removed += remove_indexed_items(&mut continuation.params, slots);
299        }
300        if let CpsTerminator::Continue { target, args } = &mut continuation.terminator {
301            if let Some(slots) = unused_slots.get(target) {
302                remove_indexed_items(args, slots);
303            }
304        }
305    }
306    removed
307}
308
309fn unused_continuation_param_slots(
310    function: &CpsReprAbiFunction,
311) -> HashMap<CpsContinuationId, HashSet<usize>> {
312    let references = continuation_references(function);
313    let protected = protected_continuations(function);
314    let used_values = function_used_values(function);
315
316    function
317        .continuations
318        .iter()
319        .filter(|continuation| !protected.contains(&continuation.id))
320        .filter(|continuation| {
321            references
322                .get(&continuation.id)
323                .is_some_and(|reference| reference.total == reference.continue_calls)
324        })
325        .filter_map(|continuation| {
326            let slots = continuation
327                .params
328                .iter()
329                .enumerate()
330                .filter_map(|(index, param)| (!used_values.contains(&param.value)).then_some(index))
331                .collect::<HashSet<_>>();
332            (!slots.is_empty()).then_some((continuation.id, slots))
333        })
334        .collect()
335}
336
337fn function_used_values(function: &CpsReprAbiFunction) -> HashSet<CpsValueId> {
338    let mut used = HashSet::new();
339    for continuation in &function.continuations {
340        used.extend(continuation.environment.iter().map(|slot| slot.value));
341        for stmt in &continuation.stmts {
342            used.extend(stmt_operands(stmt));
343        }
344        used.extend(terminator_values(&continuation.terminator));
345    }
346    used
347}
348
/// Removes the elements of `items` whose original positions are in `slots`.
///
/// Surviving elements keep their relative order. Positions past the end of
/// `items` are ignored. Returns how many elements were removed.
fn remove_indexed_items<T>(items: &mut Vec<T>, slots: &HashSet<usize>) -> usize {
    let original_len = items.len();
    // `Vec::retain` visits elements in order exactly once, so a running
    // counter recovers each element's original index.
    let mut position = 0usize;
    items.retain(|_| {
        let current = position;
        position += 1;
        !slots.contains(&current)
    });
    original_len - items.len()
}
359
360fn inline_single_use_continuation_calls(output: &mut CpsOptimizationOutput) {
361    output.profile.passes_run += 1;
362    for function in output
363        .module
364        .functions
365        .iter_mut()
366        .chain(&mut output.module.roots)
367    {
368        let candidates = inline_candidates(function);
369        if candidates.is_empty() {
370            continue;
371        }
372        for index in 0..function.continuations.len() {
373            output.profile.inlined_continuation_calls +=
374                inline_continuation_call_at(function, index, &candidates);
375        }
376    }
377    output.profile.changed |= output.profile.inlined_continuation_calls > 0;
378}
379
380fn prune_unreachable_continuations(output: &mut CpsOptimizationOutput) {
381    output.profile.passes_run += 1;
382    for function in output
383        .module
384        .functions
385        .iter_mut()
386        .chain(&mut output.module.roots)
387    {
388        let reachable = reachable_continuations(function);
389        let before = function.continuations.len();
390        function
391            .continuations
392            .retain(|continuation| reachable.contains(&continuation.id));
393        output.profile.removed_unreachable_continuations += before - function.continuations.len();
394    }
395    output.profile.changed |= output.profile.removed_unreachable_continuations > 0;
396}
397
398fn eliminate_dead_pure_statements(output: &mut CpsOptimizationOutput) {
399    output.profile.passes_run += 1;
400    for function in output
401        .module
402        .functions
403        .iter_mut()
404        .chain(&mut output.module.roots)
405    {
406        let captured_values = function_captured_values(function);
407        for continuation in &mut function.continuations {
408            output.profile.removed_dead_pure_statements +=
409                eliminate_dead_pure_statements_in_continuation(continuation, &captured_values);
410        }
411    }
412    output.profile.changed |= output.profile.removed_dead_pure_statements > 0;
413}
414
415fn analyze_direct_style_islands(output: &mut CpsOptimizationOutput) {
416    output.profile.direct_style_islands = 0;
417    output.profile.direct_style_continuations = 0;
418    for function in output.module.functions.iter().chain(&output.module.roots) {
419        let islands = direct_style_islands(function);
420        output.profile.direct_style_islands += islands.len();
421        output.profile.direct_style_continuations += islands
422            .iter()
423            .map(|island| island.continuations.len())
424            .sum::<usize>();
425    }
426}
427
428fn maybe_trace_profile(profile: &CpsOptimizationProfile) {
429    if std::env::var_os("YULANG_CPS_OPT_TRACE").is_none() {
430        return;
431    }
432    eprintln!(
433        "[CPS-OPT] functions={} roots={} continuations={} optimized_continuations={} handlers={} statements={} optimized_statements={} passes={} forwarded_continuation_calls={} returned_continuation_calls={} folded_constant_branches={} rewritten_pure_effectful_calls={} reified_primitive_calls={} reified_partial_closure_calls={} reified_known_closure_parameter_calls={} removed_unused_continuation_params={} folded_structural_projections={} inlined_pure_direct_calls={} inlined_continuation_calls={} removed_unreachable_continuations={} removed_dead_pure_statements={} direct_style_islands={} direct_style_continuations={} changed={}",
434        profile.functions,
435        profile.roots,
436        profile.continuations,
437        profile.optimized_continuations,
438        profile.handlers,
439        profile.statements,
440        profile.optimized_statements,
441        profile.passes_run,
442        profile.forwarded_continuation_calls,
443        profile.returned_continuation_calls,
444        profile.folded_constant_branches,
445        profile.rewritten_pure_effectful_calls,
446        profile.reified_primitive_calls,
447        profile.reified_partial_closure_calls,
448        profile.reified_known_closure_parameter_calls,
449        profile.removed_unused_continuation_params,
450        profile.folded_structural_projections,
451        profile.inlined_pure_direct_calls,
452        profile.inlined_continuation_calls,
453        profile.removed_unreachable_continuations,
454        profile.removed_dead_pure_statements,
455        profile.direct_style_islands,
456        profile.direct_style_continuations,
457        profile.changed
458    );
459}
460
461fn primitive_wrappers(module: &CpsReprAbiModule) -> HashMap<String, PrimitiveWrapper> {
462    module
463        .functions
464        .iter()
465        .chain(&module.roots)
466        .filter_map(primitive_wrapper)
467        .collect()
468}
469
470fn primitive_wrapper(function: &CpsReprAbiFunction) -> Option<(String, PrimitiveWrapper)> {
471    if !function.handlers.is_empty() {
472        return None;
473    }
474    let continuation = function
475        .continuations
476        .iter()
477        .find(|continuation| continuation.id == function.entry)?;
478    if !continuation.environment.is_empty() || continuation.stmts.len() != 1 {
479        return None;
480    }
481    let [CpsStmt::Primitive { dest, op, args }] = continuation.stmts.as_slice() else {
482        return None;
483    };
484    if !matches!(continuation.terminator, CpsTerminator::Return(value) if value == *dest) {
485        return None;
486    }
487    let params = continuation
488        .params
489        .iter()
490        .map(|param| param.value)
491        .collect::<Vec<_>>();
492    if function
493        .params
494        .iter()
495        .map(|param| param.value)
496        .collect::<Vec<_>>()
497        != params
498    {
499        return None;
500    }
501    if *args != params {
502        return None;
503    }
504    Some((
505        function.name.clone(),
506        PrimitiveWrapper {
507            op: *op,
508            arity: params.len(),
509        },
510    ))
511}
512
513fn reify_direct_primitive_stmt(
514    stmt: &mut CpsStmt,
515    primitives: &HashMap<String, PrimitiveWrapper>,
516) -> usize {
517    let CpsStmt::DirectCall { dest, target, args } = stmt else {
518        return 0;
519    };
520    let Some(primitive) = primitives.get(target) else {
521        return 0;
522    };
523    if primitive.arity != args.len() {
524        return 0;
525    }
526    *stmt = CpsStmt::Primitive {
527        dest: *dest,
528        op: primitive.op,
529        args: args.clone(),
530    };
531    1
532}
533
534#[derive(Debug, Clone, Copy, PartialEq, Eq)]
535struct PrimitiveWrapper {
536    op: typed_ir::PrimitiveOp,
537    arity: usize,
538}
539
540fn pure_callable_functions(module: &CpsReprAbiModule) -> HashSet<String> {
541    module
542        .functions
543        .iter()
544        .filter(|function| function_is_pure_callable(function))
545        .map(|function| function.name.clone())
546        .collect()
547}
548
549fn function_is_pure_callable(function: &CpsReprAbiFunction) -> bool {
550    function.handlers.is_empty()
551        && function
552            .continuations
553            .iter()
554            .all(|continuation| continuation.environment.is_empty())
555        && function
556            .continuations
557            .iter()
558            .flat_map(|continuation| &continuation.stmts)
559            .all(stmt_is_direct_call_safe)
560        && function
561            .continuations
562            .iter()
563            .all(|continuation| terminator_is_direct_call_safe(&continuation.terminator))
564}
565
566fn stmt_is_direct_call_safe(stmt: &CpsStmt) -> bool {
567    matches!(
568        stmt,
569        CpsStmt::Literal { .. }
570            | CpsStmt::Tuple { .. }
571            | CpsStmt::Record { .. }
572            | CpsStmt::RecordWithoutFields { .. }
573            | CpsStmt::Variant { .. }
574            | CpsStmt::Select { .. }
575            | CpsStmt::SelectWithDefault { .. }
576            | CpsStmt::RecordHasField { .. }
577            | CpsStmt::TupleGet { .. }
578            | CpsStmt::VariantTagEq { .. }
579            | CpsStmt::VariantPayload { .. }
580            | CpsStmt::Primitive { .. }
581            | CpsStmt::DirectCall { .. }
582    )
583}
584
585fn terminator_is_direct_call_safe(terminator: &CpsTerminator) -> bool {
586    matches!(
587        terminator,
588        CpsTerminator::Return(_) | CpsTerminator::Continue { .. } | CpsTerminator::Branch { .. }
589    )
590}
591
592fn rewrite_pure_effectful_calls_in_function(
593    function: &mut CpsReprAbiFunction,
594    pure_functions: &HashSet<String>,
595) -> usize {
596    let resumable = scalar_resume_continuations(function);
597    let mut next_value = next_function_value_id(function);
598    let mut count = 0;
599
600    for continuation in &mut function.continuations {
601        let CpsTerminator::EffectfulCall {
602            target,
603            args,
604            resume,
605        } = &continuation.terminator
606        else {
607            continue;
608        };
609        if !pure_functions.contains(target) || !resumable.contains(resume) {
610            continue;
611        }
612        let dest = next_value;
613        next_value.0 += 1;
614        continuation.stmts.push(CpsStmt::DirectCall {
615            dest,
616            target: target.clone(),
617            args: args.clone(),
618        });
619        continuation.terminator = CpsTerminator::Continue {
620            target: *resume,
621            args: vec![dest],
622        };
623        count += 1;
624    }
625
626    count
627}
628
629fn fold_constant_branch_in_continuation(
630    continuation: &mut CpsReprAbiContinuation,
631    empty_param_continuations: &HashSet<CpsContinuationId>,
632) -> usize {
633    let (cond, then_cont, else_cont) = match &continuation.terminator {
634        CpsTerminator::Branch {
635            cond,
636            then_cont,
637            else_cont,
638        } => (*cond, *then_cont, *else_cont),
639        _ => return 0,
640    };
641    let Some(value) = local_bool_literal(continuation, cond) else {
642        return 0;
643    };
644    let target = if value { then_cont } else { else_cont };
645    if !empty_param_continuations.contains(&target) {
646        return 0;
647    }
648    continuation.terminator = CpsTerminator::Continue {
649        target,
650        args: Vec::new(),
651    };
652    1
653}
654
655fn local_bool_literal(continuation: &CpsReprAbiContinuation, value: CpsValueId) -> Option<bool> {
656    continuation.stmts.iter().find_map(|stmt| match stmt {
657        CpsStmt::Literal {
658            dest,
659            literal: crate::cps_ir::CpsLiteral::Bool(bool_value),
660        } if *dest == value => Some(*bool_value),
661        _ => None,
662    })
663}
664
665fn scalar_resume_continuations(function: &CpsReprAbiFunction) -> HashSet<CpsContinuationId> {
666    function
667        .continuations
668        .iter()
669        .filter(|continuation| {
670            continuation.environment.is_empty() && continuation.params.len() == 1
671        })
672        .map(|continuation| continuation.id)
673        .collect()
674}
675
676fn partial_closure_wrappers(
677    function: &CpsReprAbiFunction,
678) -> HashMap<CpsContinuationId, PartialClosureWrapper> {
679    function
680        .continuations
681        .iter()
682        .filter_map(partial_closure_wrapper)
683        .collect()
684}
685
686fn partial_closure_wrapper(
687    continuation: &CpsReprAbiContinuation,
688) -> Option<(CpsContinuationId, PartialClosureWrapper)> {
689    if continuation.params.len() != 1 || continuation.stmts.len() != 1 {
690        return None;
691    }
692    let [stmt] = continuation.stmts.as_slice() else {
693        return None;
694    };
695    let Some((dest, call, args)) = partial_closure_call_shape(stmt) else {
696        return None;
697    };
698    if !matches!(continuation.terminator, CpsTerminator::Return(value) if value == dest) {
699        return None;
700    }
701    let captured = continuation
702        .environment
703        .iter()
704        .map(|slot| slot.value)
705        .collect::<Vec<_>>();
706    let param = continuation.params[0].value;
707    if args.len() != captured.len() + 1 {
708        return None;
709    }
710    if args[..captured.len()] != captured {
711        return None;
712    }
713    if args[captured.len()] != param {
714        return None;
715    }
716    Some((continuation.id, PartialClosureWrapper { call, captured }))
717}
718
719fn partial_closure_call_shape(
720    stmt: &CpsStmt,
721) -> Option<(CpsValueId, PartialClosureCall, &[CpsValueId])> {
722    match stmt {
723        CpsStmt::DirectCall { dest, target, args } => Some((
724            *dest,
725            PartialClosureCall::Direct {
726                target: target.clone(),
727            },
728            args,
729        )),
730        CpsStmt::Primitive { dest, op, args } => {
731            Some((*dest, PartialClosureCall::Primitive { op: *op }, args))
732        }
733        _ => None,
734    }
735}
736
737fn reify_local_partial_closure_calls_in_continuation(
738    continuation: &mut CpsReprAbiContinuation,
739    wrappers: &HashMap<CpsContinuationId, PartialClosureWrapper>,
740    resumable: &HashSet<CpsContinuationId>,
741    next_value: &mut CpsValueId,
742) -> usize {
743    reify_partial_closure_calls_in_continuation(
744        continuation,
745        wrappers,
746        &HashMap::new(),
747        resumable,
748        next_value,
749    )
750}
751
752fn reify_known_closure_parameter_calls_in_function(
753    function: &mut CpsReprAbiFunction,
754    wrappers: &HashMap<CpsContinuationId, PartialClosureWrapper>,
755) -> usize {
756    let closure_values = local_closure_values(function, wrappers);
757    if closure_values.is_empty() {
758        return 0;
759    }
760    let parameter_wrappers = known_closure_parameter_wrappers(function, &closure_values);
761    if parameter_wrappers.is_empty() {
762        return 0;
763    }
764
765    let resumable = scalar_resume_continuations(function);
766    let mut next_value = next_function_value_id(function);
767    let mut count = 0;
768    for continuation in &mut function.continuations {
769        let Some(initial_closures) = parameter_wrappers.get(&continuation.id) else {
770            continue;
771        };
772        count += reify_partial_closure_calls_in_continuation(
773            continuation,
774            wrappers,
775            initial_closures,
776            &resumable,
777            &mut next_value,
778        );
779    }
780    count
781}
782
783fn reify_partial_closure_calls_in_continuation(
784    continuation: &mut CpsReprAbiContinuation,
785    wrappers: &HashMap<CpsContinuationId, PartialClosureWrapper>,
786    initial_closures: &HashMap<CpsValueId, PartialClosureWrapper>,
787    resumable: &HashSet<CpsContinuationId>,
788    next_value: &mut CpsValueId,
789) -> usize {
790    let mut closures = initial_closures.clone();
791    let mut count = 0;
792    for stmt in &mut continuation.stmts {
793        match stmt {
794            CpsStmt::MakeClosure { dest, entry } => {
795                if let Some(wrapper) = wrappers.get(entry) {
796                    closures.insert(*dest, wrapper.clone());
797                }
798            }
799            CpsStmt::MakeRecursiveClosure { dest, .. } => {
800                closures.remove(dest);
801            }
802            CpsStmt::ApplyClosure { dest, closure, arg } => {
803                let Some(wrapper) = closures.get(closure) else {
804                    continue;
805                };
806                let mut args = wrapper.captured.clone();
807                args.push(*arg);
808                *stmt = wrapper.call.to_stmt(*dest, args);
809                count += 1;
810            }
811            _ => {}
812        }
813    }
814    count += reify_partial_closure_terminator(
815        &mut continuation.stmts,
816        &mut continuation.terminator,
817        &closures,
818        resumable,
819        next_value,
820    );
821    count
822}
823
824fn reify_partial_closure_terminator(
825    stmts: &mut Vec<CpsStmt>,
826    terminator: &mut CpsTerminator,
827    closures: &HashMap<CpsValueId, PartialClosureWrapper>,
828    resumable: &HashSet<CpsContinuationId>,
829    next_value: &mut CpsValueId,
830) -> usize {
831    let (closure, arg, resume) = match terminator {
832        CpsTerminator::EffectfulApply {
833            closure,
834            arg,
835            resume,
836        } => (*closure, *arg, *resume),
837        _ => return 0,
838    };
839    let Some(wrapper) = closures.get(&closure) else {
840        return 0;
841    };
842    let mut args = wrapper.captured.clone();
843    args.push(arg);
844    match &wrapper.call {
845        PartialClosureCall::Direct { target } => {
846            *terminator = CpsTerminator::EffectfulCall {
847                target: target.clone(),
848                args,
849                resume,
850            };
851            1
852        }
853        PartialClosureCall::Primitive { op } if resumable.contains(&resume) => {
854            let dest = *next_value;
855            next_value.0 += 1;
856            stmts.push(CpsStmt::Primitive {
857                dest,
858                op: *op,
859                args,
860            });
861            *terminator = CpsTerminator::Continue {
862                target: resume,
863                args: vec![dest],
864            };
865            1
866        }
867        PartialClosureCall::Primitive { .. } => 0,
868    }
869}
870
871fn local_closure_values(
872    function: &CpsReprAbiFunction,
873    wrappers: &HashMap<CpsContinuationId, PartialClosureWrapper>,
874) -> HashMap<CpsValueId, PartialClosureWrapper> {
875    let mut closures = HashMap::new();
876    for continuation in &function.continuations {
877        for stmt in &continuation.stmts {
878            match stmt {
879                CpsStmt::MakeClosure { dest, entry } => {
880                    let Some(wrapper) = wrappers.get(entry) else {
881                        continue;
882                    };
883                    closures.insert(*dest, wrapper.clone());
884                }
885                CpsStmt::MakeRecursiveClosure { dest, .. } => {
886                    closures.remove(dest);
887                }
888                _ => {}
889            }
890        }
891    }
892    closures
893}
894
895fn known_closure_parameter_wrappers(
896    function: &CpsReprAbiFunction,
897    closure_values: &HashMap<CpsValueId, PartialClosureWrapper>,
898) -> HashMap<CpsContinuationId, HashMap<CpsValueId, PartialClosureWrapper>> {
899    let continuations = function
900        .continuations
901        .iter()
902        .map(|continuation| (continuation.id, continuation))
903        .collect::<HashMap<_, _>>();
904    let references = continuation_references(function);
905    let protected = protected_continuations(function);
906    let mut candidates: HashMap<CpsContinuationId, Vec<KnownClosureParameterCandidate>> =
907        HashMap::new();
908    let mut blocked = HashSet::new();
909
910    for continuation in &function.continuations {
911        let CpsTerminator::Continue { target, args } = &continuation.terminator else {
912            continue;
913        };
914        if protected.contains(target) {
915            continue;
916        }
917        let Some(target_continuation) = continuations.get(target) else {
918            continue;
919        };
920        let Some(reference) = references.get(target) else {
921            continue;
922        };
923        if reference.total != reference.continue_calls
924            || args.len() != target_continuation.params.len()
925        {
926            blocked.insert(*target);
927            continue;
928        }
929
930        let slots = candidates.entry(*target).or_insert_with(|| {
931            vec![KnownClosureParameterCandidate::Unseen; target_continuation.params.len()]
932        });
933        for (index, arg) in args.iter().enumerate() {
934            let adapted = closure_values
935                .get(arg)
936                .and_then(|wrapper| wrapper.rebase_for_continue(args, &target_continuation.params));
937            slots[index].merge(adapted);
938        }
939    }
940
941    blocked.into_iter().for_each(|target| {
942        candidates.remove(&target);
943    });
944
945    candidates
946        .into_iter()
947        .filter_map(|(target, slots)| {
948            let continuation = continuations.get(&target)?;
949            let known = continuation
950                .params
951                .iter()
952                .zip(slots)
953                .filter_map(|(param, slot)| match slot {
954                    KnownClosureParameterCandidate::Known(wrapper) => Some((param.value, wrapper)),
955                    KnownClosureParameterCandidate::Unseen
956                    | KnownClosureParameterCandidate::Conflict => None,
957                })
958                .collect::<HashMap<_, _>>();
959            (!known.is_empty()).then_some((target, known))
960        })
961        .collect()
962}
963
964#[derive(Debug, Clone, PartialEq, Eq)]
965enum KnownClosureParameterCandidate {
966    Unseen,
967    Known(PartialClosureWrapper),
968    Conflict,
969}
970
971impl KnownClosureParameterCandidate {
972    fn merge(&mut self, wrapper: Option<PartialClosureWrapper>) {
973        let Some(wrapper) = wrapper else {
974            *self = Self::Conflict;
975            return;
976        };
977        match self {
978            Self::Unseen => *self = Self::Known(wrapper),
979            Self::Known(current) if *current == wrapper => {}
980            Self::Known(_) | Self::Conflict => *self = Self::Conflict,
981        }
982    }
983}
984
985#[derive(Debug, Clone, PartialEq, Eq)]
986struct PartialClosureWrapper {
987    call: PartialClosureCall,
988    captured: Vec<CpsValueId>,
989}
990
991impl PartialClosureWrapper {
992    fn rebase_for_continue(
993        &self,
994        supplied_args: &[CpsValueId],
995        target_params: &[CpsReprAbiValue],
996    ) -> Option<Self> {
997        if supplied_args.len() != target_params.len() {
998            return None;
999        }
1000        let captured = self
1001            .captured
1002            .iter()
1003            .map(|captured| {
1004                supplied_args
1005                    .iter()
1006                    .position(|arg| arg == captured)
1007                    .map(|index| target_params[index].value)
1008            })
1009            .collect::<Option<Vec<_>>>()?;
1010        Some(Self {
1011            call: self.call.clone(),
1012            captured,
1013        })
1014    }
1015}
1016
1017#[derive(Debug, Clone, PartialEq, Eq)]
1018enum PartialClosureCall {
1019    Direct { target: String },
1020    Primitive { op: typed_ir::PrimitiveOp },
1021}
1022
1023impl PartialClosureCall {
1024    fn to_stmt(&self, dest: CpsValueId, args: Vec<CpsValueId>) -> CpsStmt {
1025        match self {
1026            PartialClosureCall::Direct { target } => CpsStmt::DirectCall {
1027                dest,
1028                target: target.clone(),
1029                args,
1030            },
1031            PartialClosureCall::Primitive { op } => CpsStmt::Primitive {
1032                dest,
1033                op: *op,
1034                args,
1035            },
1036        }
1037    }
1038}
1039
1040fn pure_direct_inline_candidates(module: &CpsReprAbiModule) -> HashMap<String, PureDirectInline> {
1041    module
1042        .functions
1043        .iter()
1044        .filter_map(pure_direct_inline_candidate)
1045        .collect()
1046}
1047
1048fn pure_direct_inline_candidate(
1049    function: &CpsReprAbiFunction,
1050) -> Option<(String, PureDirectInline)> {
1051    if !function.handlers.is_empty() || function.continuations.len() != 1 {
1052        return None;
1053    }
1054    let continuation = function
1055        .continuations
1056        .iter()
1057        .find(|continuation| continuation.id == function.entry)?;
1058    if !continuation.environment.is_empty() || continuation.stmts.len() > 16 {
1059        return None;
1060    }
1061    if continuation.params.len() != function.params.len() {
1062        return None;
1063    }
1064    if continuation
1065        .params
1066        .iter()
1067        .map(|param| param.value)
1068        .ne(function.params.iter().map(|param| param.value))
1069    {
1070        return None;
1071    }
1072    if !continuation.stmts.iter().all(pure_direct_inline_stmt) {
1073        return None;
1074    }
1075    let CpsTerminator::Return(result) = continuation.terminator else {
1076        return None;
1077    };
1078    if !continuation
1079        .stmts
1080        .iter()
1081        .any(|stmt| stmt_dest(stmt) == Some(result))
1082    {
1083        return None;
1084    }
1085    Some((
1086        function.name.clone(),
1087        PureDirectInline {
1088            params: continuation
1089                .params
1090                .iter()
1091                .map(|param| param.value)
1092                .collect(),
1093            stmts: continuation.stmts.clone(),
1094            result,
1095        },
1096    ))
1097}
1098
1099fn pure_direct_inline_stmt(stmt: &CpsStmt) -> bool {
1100    matches!(
1101        stmt,
1102        CpsStmt::Literal { .. }
1103            | CpsStmt::Tuple { .. }
1104            | CpsStmt::Record { .. }
1105            | CpsStmt::RecordWithoutFields { .. }
1106            | CpsStmt::Variant { .. }
1107            | CpsStmt::Select { .. }
1108            | CpsStmt::SelectWithDefault { .. }
1109            | CpsStmt::RecordHasField { .. }
1110            | CpsStmt::TupleGet { .. }
1111            | CpsStmt::VariantTagEq { .. }
1112            | CpsStmt::VariantPayload { .. }
1113            | CpsStmt::Primitive { .. }
1114    )
1115}
1116
1117fn inline_pure_direct_calls_in_function(
1118    function: &mut CpsReprAbiFunction,
1119    candidates: &HashMap<String, PureDirectInline>,
1120) -> usize {
1121    let mut next_value = next_function_value_id(function);
1122    let mut count = 0;
1123    for continuation in &mut function.continuations {
1124        let mut stmts = Vec::with_capacity(continuation.stmts.len());
1125        for stmt in continuation.stmts.drain(..) {
1126            let CpsStmt::DirectCall { dest, target, args } = &stmt else {
1127                stmts.push(stmt);
1128                continue;
1129            };
1130            let Some(candidate) = candidates.get(target) else {
1131                stmts.push(stmt);
1132                continue;
1133            };
1134            if candidate.params.len() != args.len() {
1135                stmts.push(stmt);
1136                continue;
1137            }
1138            let mut substitution = candidate
1139                .params
1140                .iter()
1141                .copied()
1142                .zip(args.iter().copied())
1143                .collect::<HashMap<_, _>>();
1144            for stmt in &candidate.stmts {
1145                if let Some(value) = stmt_dest(stmt) {
1146                    substitution.entry(value).or_insert_with(|| {
1147                        let fresh = next_value;
1148                        next_value.0 += 1;
1149                        fresh
1150                    });
1151                }
1152            }
1153            substitution.insert(candidate.result, *dest);
1154            stmts.extend(
1155                candidate
1156                    .stmts
1157                    .iter()
1158                    .cloned()
1159                    .map(|stmt| substitute_pure_inline_stmt_values(stmt, &substitution)),
1160            );
1161            count += 1;
1162        }
1163        continuation.stmts = stmts;
1164    }
1165    count
1166}
1167
1168fn substitute_pure_inline_stmt_values(
1169    stmt: CpsStmt,
1170    substitution: &HashMap<CpsValueId, CpsValueId>,
1171) -> CpsStmt {
1172    match stmt {
1173        CpsStmt::Literal { dest, literal } => CpsStmt::Literal {
1174            dest: subst_value(dest, substitution),
1175            literal,
1176        },
1177        CpsStmt::Tuple { dest, items } => CpsStmt::Tuple {
1178            dest: subst_value(dest, substitution),
1179            items: subst_values(items, substitution),
1180        },
1181        CpsStmt::Record { dest, base, fields } => CpsStmt::Record {
1182            dest: subst_value(dest, substitution),
1183            base: base.map(|value| subst_value(value, substitution)),
1184            fields: fields
1185                .into_iter()
1186                .map(|field| CpsRecordField {
1187                    name: field.name,
1188                    value: subst_value(field.value, substitution),
1189                })
1190                .collect(),
1191        },
1192        CpsStmt::RecordWithoutFields { dest, base, fields } => CpsStmt::RecordWithoutFields {
1193            dest: subst_value(dest, substitution),
1194            base: subst_value(base, substitution),
1195            fields,
1196        },
1197        CpsStmt::Variant { dest, tag, value } => CpsStmt::Variant {
1198            dest: subst_value(dest, substitution),
1199            tag,
1200            value: value.map(|value| subst_value(value, substitution)),
1201        },
1202        CpsStmt::Select { dest, base, field } => CpsStmt::Select {
1203            dest: subst_value(dest, substitution),
1204            base: subst_value(base, substitution),
1205            field,
1206        },
1207        CpsStmt::SelectWithDefault {
1208            dest,
1209            base,
1210            field,
1211            default,
1212        } => CpsStmt::SelectWithDefault {
1213            dest: subst_value(dest, substitution),
1214            base: subst_value(base, substitution),
1215            field,
1216            default: subst_value(default, substitution),
1217        },
1218        CpsStmt::RecordHasField { dest, base, field } => CpsStmt::RecordHasField {
1219            dest: subst_value(dest, substitution),
1220            base: subst_value(base, substitution),
1221            field,
1222        },
1223        CpsStmt::TupleGet { dest, tuple, index } => CpsStmt::TupleGet {
1224            dest: subst_value(dest, substitution),
1225            tuple: subst_value(tuple, substitution),
1226            index,
1227        },
1228        CpsStmt::VariantTagEq { dest, variant, tag } => CpsStmt::VariantTagEq {
1229            dest: subst_value(dest, substitution),
1230            variant: subst_value(variant, substitution),
1231            tag,
1232        },
1233        CpsStmt::VariantPayload { dest, variant } => CpsStmt::VariantPayload {
1234            dest: subst_value(dest, substitution),
1235            variant: subst_value(variant, substitution),
1236        },
1237        CpsStmt::Primitive { dest, op, args } => CpsStmt::Primitive {
1238            dest: subst_value(dest, substitution),
1239            op,
1240            args: subst_values(args, substitution),
1241        },
1242        stmt => stmt,
1243    }
1244}
1245
1246fn next_function_value_id(function: &CpsReprAbiFunction) -> CpsValueId {
1247    let max_value = function
1248        .params
1249        .iter()
1250        .map(|value| value.value)
1251        .chain(
1252            function
1253                .continuations
1254                .iter()
1255                .flat_map(continuation_value_ids),
1256        )
1257        .map(|value| value.0)
1258        .max()
1259        .unwrap_or(0);
1260    CpsValueId(max_value + 1)
1261}
1262
1263fn continuation_value_ids(
1264    continuation: &CpsReprAbiContinuation,
1265) -> impl Iterator<Item = CpsValueId> + '_ {
1266    continuation
1267        .params
1268        .iter()
1269        .map(|value| value.value)
1270        .chain(continuation.environment.iter().map(|slot| slot.value))
1271        .chain(continuation.stmts.iter().filter_map(stmt_dest))
1272}
1273
/// Folds `TupleGet` projections of tuples built earlier in the same
/// continuation.
///
/// A `TupleGet` is removed when it reads a tuple created by an earlier `Tuple`
/// statement in `continuation.stmts`, and the selected element resolves to a
/// value this pass has marked scalar (via `stmt_produces_scalar_value`, or by
/// an earlier fold). Its `dest` becomes an alias of that element. Aliases are
/// substituted into every later statement and into the terminator.
///
/// Returns the number of projections removed.
///
/// NOTE(review): only this continuation's statements and terminator are
/// rewritten. If a removed `dest` can be referenced from elsewhere (e.g.
/// through another continuation's environment slots), those uses are not
/// updated here — confirm callers guarantee that cannot happen.
fn fold_structural_projections_in_continuation(continuation: &mut CpsReprAbiContinuation) -> usize {
    // Folded projection `dest` -> the value it now stands for.
    let mut aliases = HashMap::<CpsValueId, CpsValueId>::new();
    // Tuple value -> its items (already alias-substituted), for tuples built here.
    let mut tuples = HashMap::<CpsValueId, Vec<CpsValueId>>::new();
    // Values classified as scalar so far; only these are used as alias targets.
    let mut scalar_values = HashSet::<CpsValueId>::new();
    let mut stmts = Vec::with_capacity(continuation.stmts.len());
    let mut count = 0;

    for stmt in continuation.stmts.drain(..) {
        // Apply aliases established by earlier folds before inspecting the stmt.
        let stmt = substitute_stmt_values(stmt, &aliases);
        match stmt {
            CpsStmt::Tuple { dest, items } => {
                tuples.insert(dest, items.clone());
                stmts.push(CpsStmt::Tuple { dest, items });
            }
            CpsStmt::TupleGet { dest, tuple, index } => {
                if let Some(items) = tuples.get(&tuple) {
                    if let Some(value) = items.get(index).copied() {
                        let value = resolve_alias(value, &aliases);
                        if scalar_values.contains(&value) {
                            // Drop the projection; later uses of `dest` are
                            // rewritten to `value` by the substitution above.
                            aliases.insert(dest, value);
                            scalar_values.insert(dest);
                            count += 1;
                            continue;
                        }
                    }
                }
                tuples.remove(&dest);
                stmts.push(CpsStmt::TupleGet { dest, tuple, index });
            }
            stmt => {
                if let Some(dest) = stmt_dest(&stmt) {
                    // Any other definition of `dest` is not a tracked tuple.
                    tuples.remove(&dest);
                    if stmt_produces_scalar_value(&stmt) {
                        scalar_values.insert(dest);
                    }
                }
                stmts.push(stmt);
            }
        }
    }

    continuation.terminator =
        substitute_terminator_values(continuation.terminator.clone(), &aliases);
    continuation.stmts = stmts;
    count
}
1320
1321fn stmt_produces_scalar_value(stmt: &CpsStmt) -> bool {
1322    matches!(
1323        stmt,
1324        CpsStmt::Literal { .. }
1325            | CpsStmt::RecordHasField { .. }
1326            | CpsStmt::VariantTagEq { .. }
1327            | CpsStmt::Primitive {
1328                op: typed_ir::PrimitiveOp::BoolNot
1329                    | typed_ir::PrimitiveOp::BoolEq
1330                    | typed_ir::PrimitiveOp::IntAdd
1331                    | typed_ir::PrimitiveOp::IntSub
1332                    | typed_ir::PrimitiveOp::IntMul
1333                    | typed_ir::PrimitiveOp::IntEq
1334                    | typed_ir::PrimitiveOp::IntLt
1335                    | typed_ir::PrimitiveOp::IntLe
1336                    | typed_ir::PrimitiveOp::IntGt
1337                    | typed_ir::PrimitiveOp::IntGe
1338                    | typed_ir::PrimitiveOp::FloatAdd
1339                    | typed_ir::PrimitiveOp::FloatSub
1340                    | typed_ir::PrimitiveOp::FloatMul
1341                    | typed_ir::PrimitiveOp::FloatEq
1342                    | typed_ir::PrimitiveOp::FloatLt
1343                    | typed_ir::PrimitiveOp::FloatLe
1344                    | typed_ir::PrimitiveOp::FloatGt
1345                    | typed_ir::PrimitiveOp::FloatGe,
1346                ..
1347            }
1348    )
1349}
1350
1351fn resolve_alias(mut value: CpsValueId, aliases: &HashMap<CpsValueId, CpsValueId>) -> CpsValueId {
1352    let mut seen = HashSet::new();
1353    while let Some(next) = aliases.get(&value).copied() {
1354        if !seen.insert(value) {
1355            break;
1356        }
1357        value = next;
1358    }
1359    value
1360}
1361
/// Inlinable body of a pure, single-continuation function, as produced by
/// `pure_direct_inline_candidate`.
#[derive(Debug, Clone, PartialEq, Eq)]
struct PureDirectInline {
    /// The function's parameter values in order; renamed to the call's
    /// arguments when inlined.
    params: Vec<CpsValueId>,
    /// The entry continuation's statements, each accepted by
    /// `pure_direct_inline_stmt`.
    stmts: Vec<CpsStmt>,
    /// The returned value; always the destination of one of `stmts`.
    result: CpsValueId,
}

/// A group of direct-style candidate continuations, as produced by
/// `direct_style_islands`.
#[derive(Debug, Clone, PartialEq, Eq)]
struct DirectStyleIsland {
    /// Ids of the continuations in this island, sorted ascending.
    continuations: Vec<CpsContinuationId>,
}
1373
1374fn direct_style_islands(function: &CpsReprAbiFunction) -> Vec<DirectStyleIsland> {
1375    let candidates = function
1376        .continuations
1377        .iter()
1378        .filter(|continuation| direct_style_candidate(continuation))
1379        .map(|continuation| continuation.id)
1380        .collect::<HashSet<_>>();
1381    if candidates.is_empty() {
1382        return Vec::new();
1383    }
1384
1385    let continuations = function
1386        .continuations
1387        .iter()
1388        .map(|continuation| (continuation.id, continuation))
1389        .collect::<HashMap<_, _>>();
1390    let mut visited = HashSet::new();
1391    let mut islands = Vec::new();
1392
1393    for start in candidates.iter().copied() {
1394        if visited.contains(&start) {
1395            continue;
1396        }
1397        let mut queue = VecDeque::from([start]);
1398        let mut island = Vec::new();
1399        visited.insert(start);
1400
1401        while let Some(id) = queue.pop_front() {
1402            island.push(id);
1403            let Some(continuation) = continuations.get(&id) else {
1404                continue;
1405            };
1406            for successor in direct_style_successors(&continuation.terminator) {
1407                if candidates.contains(&successor) && visited.insert(successor) {
1408                    queue.push_back(successor);
1409                }
1410            }
1411        }
1412
1413        island.sort();
1414        islands.push(DirectStyleIsland {
1415            continuations: island,
1416        });
1417    }
1418
1419    islands.sort_by_key(|island| island.continuations.first().copied());
1420    islands
1421}
1422
1423fn direct_style_candidate(continuation: &CpsReprAbiContinuation) -> bool {
1424    if !continuation.environment.is_empty() {
1425        return false;
1426    }
1427    continuation.stmts.iter().all(direct_style_stmt)
1428        && matches!(
1429            continuation.terminator,
1430            CpsTerminator::Return(_)
1431                | CpsTerminator::Continue { .. }
1432                | CpsTerminator::Branch { .. }
1433        )
1434}
1435
1436fn direct_style_stmt(stmt: &CpsStmt) -> bool {
1437    matches!(
1438        stmt,
1439        CpsStmt::Literal { .. }
1440            | CpsStmt::Tuple { .. }
1441            | CpsStmt::Record { .. }
1442            | CpsStmt::RecordWithoutFields { .. }
1443            | CpsStmt::Variant { .. }
1444            | CpsStmt::Select { .. }
1445            | CpsStmt::SelectWithDefault { .. }
1446            | CpsStmt::RecordHasField { .. }
1447            | CpsStmt::TupleGet { .. }
1448            | CpsStmt::VariantTagEq { .. }
1449            | CpsStmt::VariantPayload { .. }
1450            | CpsStmt::Primitive { .. }
1451            | CpsStmt::DirectCall { .. }
1452    )
1453}
1454
1455fn direct_style_successors(terminator: &CpsTerminator) -> Vec<CpsContinuationId> {
1456    match terminator {
1457        CpsTerminator::Continue { target, .. } => vec![*target],
1458        CpsTerminator::Branch {
1459            then_cont,
1460            else_cont,
1461            ..
1462        } => vec![*then_cont, *else_cont],
1463        CpsTerminator::Return(_)
1464        | CpsTerminator::Perform { .. }
1465        | CpsTerminator::EffectfulCall { .. }
1466        | CpsTerminator::EffectfulApply { .. }
1467        | CpsTerminator::EffectfulForce { .. } => Vec::new(),
1468    }
1469}
1470
1471fn eliminate_dead_pure_statements_in_continuation(
1472    continuation: &mut CpsReprAbiContinuation,
1473    captured_values: &HashSet<CpsValueId>,
1474) -> usize {
1475    let mut live = terminator_values(&continuation.terminator)
1476        .into_iter()
1477        .collect::<HashSet<_>>();
1478    live.extend(captured_values.iter().copied());
1479    let mut kept = Vec::with_capacity(continuation.stmts.len());
1480    let mut removed = 0;
1481
1482    for stmt in continuation.stmts.iter().rev() {
1483        let dest = stmt_dest(stmt);
1484        if dest.is_some_and(|dest| !live.contains(&dest)) && stmt_is_pure(stmt) {
1485            removed += 1;
1486            continue;
1487        }
1488
1489        if let Some(dest) = dest {
1490            live.remove(&dest);
1491        }
1492        live.extend(stmt_operands(stmt));
1493        kept.push(stmt.clone());
1494    }
1495
1496    kept.reverse();
1497    continuation.stmts = kept;
1498    removed
1499}
1500
1501fn function_captured_values(function: &CpsReprAbiFunction) -> HashSet<CpsValueId> {
1502    function
1503        .continuations
1504        .iter()
1505        .flat_map(|continuation| continuation.environment.iter().map(|slot| slot.value))
1506        .collect()
1507}
1508
1509fn inline_candidates(
1510    function: &CpsReprAbiFunction,
1511) -> HashMap<CpsContinuationId, CpsReprAbiContinuation> {
1512    let references = continuation_references(function);
1513    let protected = protected_continuations(function);
1514    function
1515        .continuations
1516        .iter()
1517        .filter(|continuation| {
1518            if continuation.shot_kind != CpsShotKind::OneShot {
1519                return false;
1520            }
1521            if !continuation.environment.is_empty() {
1522                return false;
1523            }
1524            if continuation.stmts.len() > 12 {
1525                return false;
1526            }
1527            references
1528                .get(&continuation.id)
1529                .is_some_and(|reference| reference.total == 1 && reference.continue_calls == 1)
1530        })
1531        .filter(|continuation| !protected.contains(&continuation.id))
1532        .map(|continuation| (continuation.id, continuation.clone()))
1533        .collect()
1534}
1535
1536fn inline_continuation_call_at(
1537    function: &mut CpsReprAbiFunction,
1538    index: usize,
1539    candidates: &HashMap<CpsContinuationId, CpsReprAbiContinuation>,
1540) -> usize {
1541    let continuation = &mut function.continuations[index];
1542    let CpsTerminator::Continue { target, args } = &continuation.terminator else {
1543        return 0;
1544    };
1545    let Some(target_continuation) = candidates.get(target) else {
1546        return 0;
1547    };
1548    if target_continuation.id == continuation.id {
1549        return 0;
1550    }
1551    if target_continuation.params.len() != args.len() {
1552        return 0;
1553    }
1554
1555    let substitution = target_continuation
1556        .params
1557        .iter()
1558        .zip(args.iter().copied())
1559        .map(|(param, arg)| (param.value, arg))
1560        .collect::<HashMap<_, _>>();
1561    continuation.stmts.extend(
1562        target_continuation
1563            .stmts
1564            .iter()
1565            .cloned()
1566            .map(|stmt| substitute_stmt_values(stmt, &substitution)),
1567    );
1568    continuation.terminator =
1569        substitute_terminator_values(target_continuation.terminator.clone(), &substitution);
1570    1
1571}
1572
/// How often one continuation id is referenced within a function, as tallied
/// by `continuation_references`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
struct ContinuationReferenceCount {
    /// All counted references. These are `Continue`/`Branch` targets, `resume`
    /// continuations of effectful terminators, thunk/closure entries, and
    /// handler entries named by `InstallHandler`/`ResumeWithHandler`
    /// statements.
    total: usize,
    /// The subset of `total` that are `Continue` terminator targets.
    continue_calls: usize,
}
1578
1579fn continuation_references(
1580    function: &CpsReprAbiFunction,
1581) -> HashMap<CpsContinuationId, ContinuationReferenceCount> {
1582    let mut references = HashMap::new();
1583    for continuation in &function.continuations {
1584        for stmt in &continuation.stmts {
1585            collect_stmt_reference_counts(stmt, &mut references);
1586        }
1587        collect_terminator_reference_counts(&continuation.terminator, &mut references);
1588    }
1589    references
1590}
1591
1592fn protected_continuations(function: &CpsReprAbiFunction) -> HashSet<CpsContinuationId> {
1593    let mut protected = HashSet::new();
1594    protected.insert(function.entry);
1595    for handler in &function.handlers {
1596        for arm in &handler.arms {
1597            protected.insert(arm.entry);
1598        }
1599    }
1600    for continuation in &function.continuations {
1601        for stmt in &continuation.stmts {
1602            collect_protected_stmt_continuations(stmt, &mut protected);
1603        }
1604    }
1605    protected
1606}
1607
1608fn collect_stmt_reference_counts(
1609    stmt: &CpsStmt,
1610    references: &mut HashMap<CpsContinuationId, ContinuationReferenceCount>,
1611) {
1612    match stmt {
1613        CpsStmt::MakeThunk { entry, .. }
1614        | CpsStmt::MakeClosure { entry, .. }
1615        | CpsStmt::MakeRecursiveClosure { entry, .. } => {
1616            count_reference(*entry, references, false);
1617        }
1618        CpsStmt::InstallHandler {
1619            value,
1620            escape,
1621            envs,
1622            ..
1623        } => {
1624            count_reference(*value, references, false);
1625            count_reference(*escape, references, false);
1626            for env in envs {
1627                count_reference(env.entry, references, false);
1628            }
1629        }
1630        CpsStmt::ResumeWithHandler { envs, .. } => {
1631            for env in envs {
1632                count_reference(env.entry, references, false);
1633            }
1634        }
1635        CpsStmt::Literal { .. }
1636        | CpsStmt::FreshGuard { .. }
1637        | CpsStmt::PeekGuard { .. }
1638        | CpsStmt::FindGuard { .. }
1639        | CpsStmt::AddThunkBoundary { .. }
1640        | CpsStmt::ForceThunk { .. }
1641        | CpsStmt::Tuple { .. }
1642        | CpsStmt::Record { .. }
1643        | CpsStmt::RecordWithoutFields { .. }
1644        | CpsStmt::Variant { .. }
1645        | CpsStmt::Select { .. }
1646        | CpsStmt::SelectWithDefault { .. }
1647        | CpsStmt::RecordHasField { .. }
1648        | CpsStmt::TupleGet { .. }
1649        | CpsStmt::VariantTagEq { .. }
1650        | CpsStmt::VariantPayload { .. }
1651        | CpsStmt::Primitive { .. }
1652        | CpsStmt::DirectCall { .. }
1653        | CpsStmt::ApplyClosure { .. }
1654        | CpsStmt::CloneContinuation { .. }
1655        | CpsStmt::Resume { .. }
1656        | CpsStmt::UninstallHandler { .. } => {}
1657    }
1658}
1659
1660fn collect_terminator_reference_counts(
1661    terminator: &CpsTerminator,
1662    references: &mut HashMap<CpsContinuationId, ContinuationReferenceCount>,
1663) {
1664    match terminator {
1665        CpsTerminator::Continue { target, .. } => count_reference(*target, references, true),
1666        CpsTerminator::Branch {
1667            then_cont,
1668            else_cont,
1669            ..
1670        } => {
1671            count_reference(*then_cont, references, false);
1672            count_reference(*else_cont, references, false);
1673        }
1674        CpsTerminator::Perform { resume, .. }
1675        | CpsTerminator::EffectfulCall { resume, .. }
1676        | CpsTerminator::EffectfulApply { resume, .. }
1677        | CpsTerminator::EffectfulForce { resume, .. } => {
1678            count_reference(*resume, references, false)
1679        }
1680        CpsTerminator::Return(_) => {}
1681    }
1682}
1683
1684fn collect_protected_stmt_continuations(
1685    stmt: &CpsStmt,
1686    protected: &mut HashSet<CpsContinuationId>,
1687) {
1688    match stmt {
1689        CpsStmt::MakeThunk { entry, .. }
1690        | CpsStmt::MakeClosure { entry, .. }
1691        | CpsStmt::MakeRecursiveClosure { entry, .. } => {
1692            protected.insert(*entry);
1693        }
1694        CpsStmt::InstallHandler {
1695            value,
1696            escape,
1697            envs,
1698            ..
1699        } => {
1700            protected.insert(*value);
1701            protected.insert(*escape);
1702            for env in envs {
1703                protected.insert(env.entry);
1704            }
1705        }
1706        CpsStmt::ResumeWithHandler { envs, .. } => {
1707            for env in envs {
1708                protected.insert(env.entry);
1709            }
1710        }
1711        CpsStmt::Literal { .. }
1712        | CpsStmt::FreshGuard { .. }
1713        | CpsStmt::PeekGuard { .. }
1714        | CpsStmt::FindGuard { .. }
1715        | CpsStmt::AddThunkBoundary { .. }
1716        | CpsStmt::ForceThunk { .. }
1717        | CpsStmt::Tuple { .. }
1718        | CpsStmt::Record { .. }
1719        | CpsStmt::RecordWithoutFields { .. }
1720        | CpsStmt::Variant { .. }
1721        | CpsStmt::Select { .. }
1722        | CpsStmt::SelectWithDefault { .. }
1723        | CpsStmt::RecordHasField { .. }
1724        | CpsStmt::TupleGet { .. }
1725        | CpsStmt::VariantTagEq { .. }
1726        | CpsStmt::VariantPayload { .. }
1727        | CpsStmt::Primitive { .. }
1728        | CpsStmt::DirectCall { .. }
1729        | CpsStmt::ApplyClosure { .. }
1730        | CpsStmt::CloneContinuation { .. }
1731        | CpsStmt::Resume { .. }
1732        | CpsStmt::UninstallHandler { .. } => {}
1733    }
1734}
1735
1736fn count_reference(
1737    id: CpsContinuationId,
1738    references: &mut HashMap<CpsContinuationId, ContinuationReferenceCount>,
1739    is_continue_call: bool,
1740) {
1741    let reference = references.entry(id).or_default();
1742    reference.total += 1;
1743    if is_continue_call {
1744        reference.continue_calls += 1;
1745    }
1746}
1747
1748fn stmt_is_pure(stmt: &CpsStmt) -> bool {
1749    matches!(
1750        stmt,
1751        CpsStmt::Literal { .. }
1752            | CpsStmt::MakeThunk { .. }
1753            | CpsStmt::AddThunkBoundary { .. }
1754            | CpsStmt::MakeClosure { .. }
1755            | CpsStmt::MakeRecursiveClosure { .. }
1756            | CpsStmt::Tuple { .. }
1757            | CpsStmt::Record { .. }
1758            | CpsStmt::RecordWithoutFields { .. }
1759            | CpsStmt::Variant { .. }
1760            | CpsStmt::Select { .. }
1761            | CpsStmt::SelectWithDefault { .. }
1762            | CpsStmt::RecordHasField { .. }
1763            | CpsStmt::TupleGet { .. }
1764            | CpsStmt::VariantTagEq { .. }
1765            | CpsStmt::Primitive {
1766                op: typed_ir::PrimitiveOp::BoolNot
1767                    | typed_ir::PrimitiveOp::BoolEq
1768                    | typed_ir::PrimitiveOp::IntAdd
1769                    | typed_ir::PrimitiveOp::IntSub
1770                    | typed_ir::PrimitiveOp::IntMul
1771                    | typed_ir::PrimitiveOp::IntEq
1772                    | typed_ir::PrimitiveOp::IntLt
1773                    | typed_ir::PrimitiveOp::IntLe
1774                    | typed_ir::PrimitiveOp::IntGt
1775                    | typed_ir::PrimitiveOp::IntGe
1776                    | typed_ir::PrimitiveOp::IntToString
1777                    | typed_ir::PrimitiveOp::IntToHex
1778                    | typed_ir::PrimitiveOp::IntToUpperHex
1779                    | typed_ir::PrimitiveOp::FloatAdd
1780                    | typed_ir::PrimitiveOp::FloatSub
1781                    | typed_ir::PrimitiveOp::FloatMul
1782                    | typed_ir::PrimitiveOp::FloatEq
1783                    | typed_ir::PrimitiveOp::FloatLt
1784                    | typed_ir::PrimitiveOp::FloatLe
1785                    | typed_ir::PrimitiveOp::FloatGt
1786                    | typed_ir::PrimitiveOp::FloatGe
1787                    | typed_ir::PrimitiveOp::FloatToString
1788                    | typed_ir::PrimitiveOp::BoolToString
1789                    | typed_ir::PrimitiveOp::StringConcat
1790                    | typed_ir::PrimitiveOp::StringLen
1791                    | typed_ir::PrimitiveOp::StringEq,
1792                ..
1793            }
1794    )
1795}
1796
1797fn stmt_dest(stmt: &CpsStmt) -> Option<CpsValueId> {
1798    match stmt {
1799        CpsStmt::Literal { dest, .. }
1800        | CpsStmt::FreshGuard { dest, .. }
1801        | CpsStmt::PeekGuard { dest }
1802        | CpsStmt::FindGuard { dest, .. }
1803        | CpsStmt::MakeThunk { dest, .. }
1804        | CpsStmt::AddThunkBoundary { dest, .. }
1805        | CpsStmt::MakeClosure { dest, .. }
1806        | CpsStmt::MakeRecursiveClosure { dest, .. }
1807        | CpsStmt::ForceThunk { dest, .. }
1808        | CpsStmt::Tuple { dest, .. }
1809        | CpsStmt::Record { dest, .. }
1810        | CpsStmt::RecordWithoutFields { dest, .. }
1811        | CpsStmt::Variant { dest, .. }
1812        | CpsStmt::Select { dest, .. }
1813        | CpsStmt::SelectWithDefault { dest, .. }
1814        | CpsStmt::RecordHasField { dest, .. }
1815        | CpsStmt::TupleGet { dest, .. }
1816        | CpsStmt::VariantTagEq { dest, .. }
1817        | CpsStmt::VariantPayload { dest, .. }
1818        | CpsStmt::Primitive { dest, .. }
1819        | CpsStmt::DirectCall { dest, .. }
1820        | CpsStmt::ApplyClosure { dest, .. }
1821        | CpsStmt::CloneContinuation { dest, .. }
1822        | CpsStmt::Resume { dest, .. }
1823        | CpsStmt::ResumeWithHandler { dest, .. } => Some(*dest),
1824        CpsStmt::InstallHandler { .. } | CpsStmt::UninstallHandler { .. } => None,
1825    }
1826}
1827
1828fn stmt_operands(stmt: &CpsStmt) -> Vec<CpsValueId> {
1829    match stmt {
1830        CpsStmt::FindGuard { guard, .. } => vec![*guard],
1831        CpsStmt::AddThunkBoundary { thunk, guard, .. } => vec![*thunk, *guard],
1832        CpsStmt::ForceThunk { thunk, .. } => vec![*thunk],
1833        CpsStmt::Tuple { items, .. } => items.clone(),
1834        CpsStmt::Record { base, fields, .. } => base
1835            .iter()
1836            .copied()
1837            .chain(fields.iter().map(|field| field.value))
1838            .collect(),
1839        CpsStmt::RecordWithoutFields { base, .. } => vec![*base],
1840        CpsStmt::Variant { value, .. } => value.iter().copied().collect(),
1841        CpsStmt::Select { base, .. } | CpsStmt::RecordHasField { base, .. } => vec![*base],
1842        CpsStmt::SelectWithDefault { base, default, .. } => vec![*base, *default],
1843        CpsStmt::TupleGet { tuple, .. } => vec![*tuple],
1844        CpsStmt::VariantTagEq { variant, .. } | CpsStmt::VariantPayload { variant, .. } => {
1845            vec![*variant]
1846        }
1847        CpsStmt::Primitive { args, .. } | CpsStmt::DirectCall { args, .. } => args.clone(),
1848        CpsStmt::ApplyClosure { closure, arg, .. } => vec![*closure, *arg],
1849        CpsStmt::CloneContinuation { source, .. } => vec![*source],
1850        CpsStmt::Resume {
1851            resumption, arg, ..
1852        } => vec![*resumption, *arg],
1853        CpsStmt::ResumeWithHandler {
1854            resumption,
1855            arg,
1856            envs,
1857            ..
1858        } => std::iter::once(*resumption)
1859            .chain(std::iter::once(*arg))
1860            .chain(envs.iter().flat_map(|env| env.values.iter().copied()))
1861            .collect(),
1862        CpsStmt::InstallHandler { envs, .. } => envs
1863            .iter()
1864            .flat_map(|env| env.values.iter().copied())
1865            .collect(),
1866        CpsStmt::Literal { .. }
1867        | CpsStmt::FreshGuard { .. }
1868        | CpsStmt::PeekGuard { .. }
1869        | CpsStmt::MakeThunk { .. }
1870        | CpsStmt::MakeClosure { .. }
1871        | CpsStmt::MakeRecursiveClosure { .. }
1872        | CpsStmt::UninstallHandler { .. } => Vec::new(),
1873    }
1874}
1875
1876fn terminator_values(terminator: &CpsTerminator) -> Vec<CpsValueId> {
1877    match terminator {
1878        CpsTerminator::Return(value) => vec![*value],
1879        CpsTerminator::Continue { args, .. } => args.clone(),
1880        CpsTerminator::Branch { cond, .. } => vec![*cond],
1881        CpsTerminator::Perform {
1882            payload, blocked, ..
1883        } => std::iter::once(*payload)
1884            .chain(blocked.iter().copied())
1885            .collect(),
1886        CpsTerminator::EffectfulCall { args, .. } => args.clone(),
1887        CpsTerminator::EffectfulApply { closure, arg, .. } => vec![*closure, *arg],
1888        CpsTerminator::EffectfulForce { thunk, .. } => vec![*thunk],
1889    }
1890}
1891
1892fn reachable_continuations(function: &CpsReprAbiFunction) -> HashSet<CpsContinuationId> {
1893    let continuations = function
1894        .continuations
1895        .iter()
1896        .map(|continuation| (continuation.id, continuation))
1897        .collect::<HashMap<_, _>>();
1898    let mut reachable = HashSet::new();
1899    let mut work = VecDeque::new();
1900
1901    push_reachable(function.entry, &mut reachable, &mut work);
1902    for handler in &function.handlers {
1903        for arm in &handler.arms {
1904            push_reachable(arm.entry, &mut reachable, &mut work);
1905        }
1906    }
1907
1908    while let Some(id) = work.pop_front() {
1909        let Some(continuation) = continuations.get(&id) else {
1910            continue;
1911        };
1912        for stmt in &continuation.stmts {
1913            collect_stmt_continuations(stmt, &mut reachable, &mut work);
1914        }
1915        collect_terminator_continuations(&continuation.terminator, &mut reachable, &mut work);
1916    }
1917
1918    reachable
1919}
1920
1921fn push_reachable(
1922    id: CpsContinuationId,
1923    reachable: &mut HashSet<CpsContinuationId>,
1924    work: &mut VecDeque<CpsContinuationId>,
1925) {
1926    if reachable.insert(id) {
1927        work.push_back(id);
1928    }
1929}
1930
1931fn collect_stmt_continuations(
1932    stmt: &CpsStmt,
1933    reachable: &mut HashSet<CpsContinuationId>,
1934    work: &mut VecDeque<CpsContinuationId>,
1935) {
1936    match stmt {
1937        CpsStmt::MakeThunk { entry, .. }
1938        | CpsStmt::MakeClosure { entry, .. }
1939        | CpsStmt::MakeRecursiveClosure { entry, .. } => {
1940            push_reachable(*entry, reachable, work);
1941        }
1942        CpsStmt::InstallHandler {
1943            value,
1944            escape,
1945            envs,
1946            ..
1947        } => {
1948            push_reachable(*value, reachable, work);
1949            push_reachable(*escape, reachable, work);
1950            for env in envs {
1951                push_reachable(env.entry, reachable, work);
1952            }
1953        }
1954        CpsStmt::ResumeWithHandler { envs, .. } => {
1955            for env in envs {
1956                push_reachable(env.entry, reachable, work);
1957            }
1958        }
1959        CpsStmt::Literal { .. }
1960        | CpsStmt::FreshGuard { .. }
1961        | CpsStmt::PeekGuard { .. }
1962        | CpsStmt::FindGuard { .. }
1963        | CpsStmt::AddThunkBoundary { .. }
1964        | CpsStmt::ForceThunk { .. }
1965        | CpsStmt::Tuple { .. }
1966        | CpsStmt::Record { .. }
1967        | CpsStmt::RecordWithoutFields { .. }
1968        | CpsStmt::Variant { .. }
1969        | CpsStmt::Select { .. }
1970        | CpsStmt::SelectWithDefault { .. }
1971        | CpsStmt::RecordHasField { .. }
1972        | CpsStmt::TupleGet { .. }
1973        | CpsStmt::VariantTagEq { .. }
1974        | CpsStmt::VariantPayload { .. }
1975        | CpsStmt::Primitive { .. }
1976        | CpsStmt::DirectCall { .. }
1977        | CpsStmt::ApplyClosure { .. }
1978        | CpsStmt::CloneContinuation { .. }
1979        | CpsStmt::Resume { .. }
1980        | CpsStmt::UninstallHandler { .. } => {}
1981    }
1982}
1983
1984fn collect_terminator_continuations(
1985    terminator: &CpsTerminator,
1986    reachable: &mut HashSet<CpsContinuationId>,
1987    work: &mut VecDeque<CpsContinuationId>,
1988) {
1989    match terminator {
1990        CpsTerminator::Continue { target, .. } => push_reachable(*target, reachable, work),
1991        CpsTerminator::Branch {
1992            then_cont,
1993            else_cont,
1994            ..
1995        } => {
1996            push_reachable(*then_cont, reachable, work);
1997            push_reachable(*else_cont, reachable, work);
1998        }
1999        CpsTerminator::Perform { resume, .. }
2000        | CpsTerminator::EffectfulCall { resume, .. }
2001        | CpsTerminator::EffectfulApply { resume, .. }
2002        | CpsTerminator::EffectfulForce { resume, .. } => push_reachable(*resume, reachable, work),
2003        CpsTerminator::Return(_) => {}
2004    }
2005}
2006
2007fn forwarding_continuations(
2008    function: &CpsReprAbiFunction,
2009) -> HashMap<CpsContinuationId, ForwardingContinuation> {
2010    let mut forwarders = HashMap::new();
2011    for continuation in &function.continuations {
2012        if !continuation.stmts.is_empty() || !continuation.environment.is_empty() {
2013            continue;
2014        }
2015        let CpsTerminator::Continue { target, args } = &continuation.terminator else {
2016            continue;
2017        };
2018        if *target == continuation.id {
2019            continue;
2020        }
2021        if args
2022            .iter()
2023            .all(|arg| continuation.params.iter().any(|param| param.value == *arg))
2024        {
2025            forwarders.insert(
2026                continuation.id,
2027                ForwardingContinuation {
2028                    params: continuation
2029                        .params
2030                        .iter()
2031                        .map(|param| param.value)
2032                        .collect(),
2033                    target: *target,
2034                    args: args.clone(),
2035                },
2036            );
2037        }
2038    }
2039    forwarders
2040}
2041
2042fn returning_continuations(
2043    function: &CpsReprAbiFunction,
2044) -> HashMap<CpsContinuationId, ReturningContinuation> {
2045    let mut returners = HashMap::new();
2046    for continuation in &function.continuations {
2047        if !continuation.stmts.is_empty() || !continuation.environment.is_empty() {
2048            continue;
2049        }
2050        let CpsTerminator::Return(value) = continuation.terminator else {
2051            continue;
2052        };
2053        if let Some(param_index) = continuation
2054            .params
2055            .iter()
2056            .position(|param| param.value == value)
2057        {
2058            returners.insert(continuation.id, ReturningContinuation { param_index });
2059        }
2060    }
2061    returners
2062}
2063
2064fn rewrite_terminator_forwarders(
2065    terminator: &mut CpsTerminator,
2066    forwarders: &HashMap<CpsContinuationId, ForwardingContinuation>,
2067) -> usize {
2068    match terminator {
2069        CpsTerminator::Continue { target, args } => {
2070            rewrite_continuation_call(target, args, forwarders)
2071        }
2072        CpsTerminator::Perform { resume, .. }
2073        | CpsTerminator::EffectfulCall { resume, .. }
2074        | CpsTerminator::EffectfulApply { resume, .. }
2075        | CpsTerminator::EffectfulForce { resume, .. } => {
2076            let mut args = Vec::new();
2077            rewrite_resume_target(resume, &mut args, forwarders)
2078        }
2079        CpsTerminator::Branch {
2080            then_cont,
2081            else_cont,
2082            ..
2083        } => {
2084            let mut count = 0;
2085            let mut args = Vec::new();
2086            count += rewrite_resume_target(then_cont, &mut args, forwarders);
2087            count += rewrite_resume_target(else_cont, &mut args, forwarders);
2088            count
2089        }
2090        CpsTerminator::Return(_) => 0,
2091    }
2092}
2093
2094fn rewrite_terminator_returners(
2095    terminator: &mut CpsTerminator,
2096    returners: &HashMap<CpsContinuationId, ReturningContinuation>,
2097) -> usize {
2098    let CpsTerminator::Continue { target, args } = terminator else {
2099        return 0;
2100    };
2101    let Some(returner) = returners.get(target) else {
2102        return 0;
2103    };
2104    let Some(value) = args.get(returner.param_index).copied() else {
2105        return 0;
2106    };
2107    *terminator = CpsTerminator::Return(value);
2108    1
2109}
2110
2111fn rewrite_continuation_call(
2112    target: &mut CpsContinuationId,
2113    args: &mut Vec<CpsValueId>,
2114    forwarders: &HashMap<CpsContinuationId, ForwardingContinuation>,
2115) -> usize {
2116    let mut count = 0;
2117    while let Some(forwarder) = forwarders.get(target) {
2118        let Some(remapped) = forwarder.remap_args(args) else {
2119            break;
2120        };
2121        *target = forwarder.target;
2122        *args = remapped;
2123        count += 1;
2124    }
2125    count
2126}
2127
2128fn rewrite_resume_target(
2129    target: &mut CpsContinuationId,
2130    args: &mut Vec<CpsValueId>,
2131    forwarders: &HashMap<CpsContinuationId, ForwardingContinuation>,
2132) -> usize {
2133    let mut count = 0;
2134    while let Some(forwarder) = forwarders.get(target) {
2135        if !forwarder.params.is_empty() {
2136            break;
2137        }
2138        if !forwarder.args.is_empty() {
2139            break;
2140        }
2141        *target = forwarder.target;
2142        args.clear();
2143        count += 1;
2144    }
2145    count
2146}
2147
2148fn substitute_stmt_values(
2149    stmt: CpsStmt,
2150    substitution: &HashMap<CpsValueId, CpsValueId>,
2151) -> CpsStmt {
2152    match stmt {
2153        CpsStmt::Literal { dest, literal } => CpsStmt::Literal { dest, literal },
2154        CpsStmt::FreshGuard { dest, var } => CpsStmt::FreshGuard { dest, var },
2155        CpsStmt::PeekGuard { dest } => CpsStmt::PeekGuard { dest },
2156        CpsStmt::FindGuard { dest, guard } => CpsStmt::FindGuard {
2157            dest,
2158            guard: subst_value(guard, substitution),
2159        },
2160        CpsStmt::MakeThunk { dest, entry } => CpsStmt::MakeThunk { dest, entry },
2161        CpsStmt::AddThunkBoundary {
2162            dest,
2163            thunk,
2164            guard,
2165            allowed,
2166            active,
2167        } => CpsStmt::AddThunkBoundary {
2168            dest,
2169            thunk: subst_value(thunk, substitution),
2170            guard: subst_value(guard, substitution),
2171            allowed,
2172            active,
2173        },
2174        CpsStmt::MakeClosure { dest, entry } => CpsStmt::MakeClosure { dest, entry },
2175        CpsStmt::MakeRecursiveClosure { dest, entry } => {
2176            CpsStmt::MakeRecursiveClosure { dest, entry }
2177        }
2178        CpsStmt::ForceThunk { dest, thunk } => CpsStmt::ForceThunk {
2179            dest,
2180            thunk: subst_value(thunk, substitution),
2181        },
2182        CpsStmt::Tuple { dest, items } => CpsStmt::Tuple {
2183            dest,
2184            items: subst_values(items, substitution),
2185        },
2186        CpsStmt::Record { dest, base, fields } => CpsStmt::Record {
2187            dest,
2188            base: base.map(|value| subst_value(value, substitution)),
2189            fields: fields
2190                .into_iter()
2191                .map(|field| CpsRecordField {
2192                    name: field.name,
2193                    value: subst_value(field.value, substitution),
2194                })
2195                .collect(),
2196        },
2197        CpsStmt::RecordWithoutFields { dest, base, fields } => CpsStmt::RecordWithoutFields {
2198            dest,
2199            base: subst_value(base, substitution),
2200            fields,
2201        },
2202        CpsStmt::Variant { dest, tag, value } => CpsStmt::Variant {
2203            dest,
2204            tag,
2205            value: value.map(|value| subst_value(value, substitution)),
2206        },
2207        CpsStmt::Select { dest, base, field } => CpsStmt::Select {
2208            dest,
2209            base: subst_value(base, substitution),
2210            field,
2211        },
2212        CpsStmt::SelectWithDefault {
2213            dest,
2214            base,
2215            field,
2216            default,
2217        } => CpsStmt::SelectWithDefault {
2218            dest,
2219            base: subst_value(base, substitution),
2220            field,
2221            default: subst_value(default, substitution),
2222        },
2223        CpsStmt::RecordHasField { dest, base, field } => CpsStmt::RecordHasField {
2224            dest,
2225            base: subst_value(base, substitution),
2226            field,
2227        },
2228        CpsStmt::TupleGet { dest, tuple, index } => CpsStmt::TupleGet {
2229            dest,
2230            tuple: subst_value(tuple, substitution),
2231            index,
2232        },
2233        CpsStmt::VariantTagEq { dest, variant, tag } => CpsStmt::VariantTagEq {
2234            dest,
2235            variant: subst_value(variant, substitution),
2236            tag,
2237        },
2238        CpsStmt::VariantPayload { dest, variant } => CpsStmt::VariantPayload {
2239            dest,
2240            variant: subst_value(variant, substitution),
2241        },
2242        CpsStmt::Primitive { dest, op, args } => CpsStmt::Primitive {
2243            dest,
2244            op,
2245            args: subst_values(args, substitution),
2246        },
2247        CpsStmt::DirectCall { dest, target, args } => CpsStmt::DirectCall {
2248            dest,
2249            target,
2250            args: subst_values(args, substitution),
2251        },
2252        CpsStmt::ApplyClosure { dest, closure, arg } => CpsStmt::ApplyClosure {
2253            dest,
2254            closure: subst_value(closure, substitution),
2255            arg: subst_value(arg, substitution),
2256        },
2257        CpsStmt::CloneContinuation { dest, source } => CpsStmt::CloneContinuation {
2258            dest,
2259            source: subst_value(source, substitution),
2260        },
2261        CpsStmt::Resume {
2262            dest,
2263            resumption,
2264            arg,
2265        } => CpsStmt::Resume {
2266            dest,
2267            resumption: subst_value(resumption, substitution),
2268            arg: subst_value(arg, substitution),
2269        },
2270        CpsStmt::ResumeWithHandler {
2271            dest,
2272            resumption,
2273            arg,
2274            handler,
2275            envs,
2276        } => CpsStmt::ResumeWithHandler {
2277            dest,
2278            resumption: subst_value(resumption, substitution),
2279            arg: subst_value(arg, substitution),
2280            handler,
2281            envs: subst_handler_envs(envs, substitution),
2282        },
2283        CpsStmt::InstallHandler {
2284            handler,
2285            envs,
2286            value,
2287            escape,
2288        } => CpsStmt::InstallHandler {
2289            handler,
2290            envs: subst_handler_envs(envs, substitution),
2291            value,
2292            escape,
2293        },
2294        CpsStmt::UninstallHandler { handler } => CpsStmt::UninstallHandler { handler },
2295    }
2296}
2297
2298fn substitute_terminator_values(
2299    terminator: CpsTerminator,
2300    substitution: &HashMap<CpsValueId, CpsValueId>,
2301) -> CpsTerminator {
2302    match terminator {
2303        CpsTerminator::Return(value) => CpsTerminator::Return(subst_value(value, substitution)),
2304        CpsTerminator::Continue { target, args } => CpsTerminator::Continue {
2305            target,
2306            args: subst_values(args, substitution),
2307        },
2308        CpsTerminator::Branch {
2309            cond,
2310            then_cont,
2311            else_cont,
2312        } => CpsTerminator::Branch {
2313            cond: subst_value(cond, substitution),
2314            then_cont,
2315            else_cont,
2316        },
2317        CpsTerminator::Perform {
2318            effect,
2319            payload,
2320            resume,
2321            handler,
2322            blocked,
2323        } => CpsTerminator::Perform {
2324            effect,
2325            payload: subst_value(payload, substitution),
2326            resume,
2327            handler,
2328            blocked: blocked.map(|value| subst_value(value, substitution)),
2329        },
2330        CpsTerminator::EffectfulCall {
2331            target,
2332            args,
2333            resume,
2334        } => CpsTerminator::EffectfulCall {
2335            target,
2336            args: subst_values(args, substitution),
2337            resume,
2338        },
2339        CpsTerminator::EffectfulApply {
2340            closure,
2341            arg,
2342            resume,
2343        } => CpsTerminator::EffectfulApply {
2344            closure: subst_value(closure, substitution),
2345            arg: subst_value(arg, substitution),
2346            resume,
2347        },
2348        CpsTerminator::EffectfulForce { thunk, resume } => CpsTerminator::EffectfulForce {
2349            thunk: subst_value(thunk, substitution),
2350            resume,
2351        },
2352    }
2353}
2354
2355fn subst_handler_envs(
2356    envs: Vec<CpsHandlerEnv>,
2357    substitution: &HashMap<CpsValueId, CpsValueId>,
2358) -> Vec<CpsHandlerEnv> {
2359    envs.into_iter()
2360        .map(|env| CpsHandlerEnv {
2361            entry: env.entry,
2362            values: subst_values(env.values, substitution),
2363            targets: subst_values(env.targets, substitution),
2364        })
2365        .collect()
2366}
2367
2368fn subst_values(
2369    values: Vec<CpsValueId>,
2370    substitution: &HashMap<CpsValueId, CpsValueId>,
2371) -> Vec<CpsValueId> {
2372    values
2373        .into_iter()
2374        .map(|value| subst_value(value, substitution))
2375        .collect()
2376}
2377
2378fn subst_value(value: CpsValueId, substitution: &HashMap<CpsValueId, CpsValueId>) -> CpsValueId {
2379    substitution.get(&value).copied().unwrap_or(value)
2380}
2381
2382#[derive(Debug, Clone, PartialEq, Eq)]
2383struct ForwardingContinuation {
2384    params: Vec<CpsValueId>,
2385    target: CpsContinuationId,
2386    args: Vec<CpsValueId>,
2387}
2388
2389#[derive(Debug, Clone, Copy, PartialEq, Eq)]
2390struct ReturningContinuation {
2391    param_index: usize,
2392}
2393
2394impl ForwardingContinuation {
2395    fn remap_args(&self, supplied_args: &[CpsValueId]) -> Option<Vec<CpsValueId>> {
2396        if supplied_args.len() != self.params.len() {
2397            return None;
2398        }
2399        self.args
2400            .iter()
2401            .map(|forwarded| {
2402                self.params
2403                    .iter()
2404                    .position(|param| param == forwarded)
2405                    .map(|index| supplied_args[index])
2406            })
2407            .collect()
2408    }
2409}
2410
2411impl CpsOptimizationProfile {
2412    fn record_optimized_size(&mut self, module: &CpsReprAbiModule) {
2413        self.optimized_continuations = module
2414            .functions
2415            .iter()
2416            .chain(&module.roots)
2417            .map(|function| function.continuations.len())
2418            .sum();
2419        self.optimized_statements = module
2420            .functions
2421            .iter()
2422            .chain(&module.roots)
2423            .flat_map(|function| &function.continuations)
2424            .map(|continuation| continuation.stmts.len())
2425            .sum();
2426    }
2427
2428    fn has_more_changes_than(self, before: Self) -> bool {
2429        self.forwarded_continuation_calls > before.forwarded_continuation_calls
2430            || self.returned_continuation_calls > before.returned_continuation_calls
2431            || self.folded_constant_branches > before.folded_constant_branches
2432            || self.rewritten_pure_effectful_calls > before.rewritten_pure_effectful_calls
2433            || self.reified_primitive_calls > before.reified_primitive_calls
2434            || self.reified_partial_closure_calls > before.reified_partial_closure_calls
2435            || self.reified_known_closure_parameter_calls
2436                > before.reified_known_closure_parameter_calls
2437            || self.removed_unused_continuation_params > before.removed_unused_continuation_params
2438            || self.folded_structural_projections > before.folded_structural_projections
2439            || self.inlined_pure_direct_calls > before.inlined_pure_direct_calls
2440            || self.inlined_continuation_calls > before.inlined_continuation_calls
2441            || self.removed_unreachable_continuations > before.removed_unreachable_continuations
2442            || self.removed_dead_pure_statements > before.removed_dead_pure_statements
2443    }
2444
2445    pub fn measure(module: &CpsReprAbiModule) -> Self {
2446        let functions = module.functions.len();
2447        let roots = module.roots.len();
2448        let continuations = module
2449            .functions
2450            .iter()
2451            .chain(&module.roots)
2452            .map(|function| function.continuations.len())
2453            .sum();
2454        let handlers = module
2455            .functions
2456            .iter()
2457            .chain(&module.roots)
2458            .map(|function| function.handlers.len())
2459            .sum();
2460        let statements = module
2461            .functions
2462            .iter()
2463            .chain(&module.roots)
2464            .flat_map(|function| &function.continuations)
2465            .map(|continuation| continuation.stmts.len())
2466            .sum();
2467
2468        Self {
2469            functions,
2470            roots,
2471            continuations,
2472            handlers,
2473            statements,
2474            optimized_continuations: continuations,
2475            optimized_statements: statements,
2476            passes_run: 0,
2477            forwarded_continuation_calls: 0,
2478            returned_continuation_calls: 0,
2479            folded_constant_branches: 0,
2480            rewritten_pure_effectful_calls: 0,
2481            reified_primitive_calls: 0,
2482            reified_partial_closure_calls: 0,
2483            reified_known_closure_parameter_calls: 0,
2484            removed_unused_continuation_params: 0,
2485            folded_structural_projections: 0,
2486            inlined_pure_direct_calls: 0,
2487            inlined_continuation_calls: 0,
2488            removed_unreachable_continuations: 0,
2489            removed_dead_pure_statements: 0,
2490            direct_style_islands: 0,
2491            direct_style_continuations: 0,
2492            changed: false,
2493        }
2494    }
2495}
2496
2497#[cfg(test)]
2498mod tests {
2499    use crate::cps_ir::{
2500        CpsContinuationId, CpsFunction, CpsLiteral, CpsModule, CpsShotKind, CpsStmt, CpsTerminator,
2501        CpsValueId,
2502    };
2503    use crate::cps_repr::lower_cps_repr_module;
2504    use crate::cps_repr_abi::lower_cps_repr_abi_module;
2505
2506    use super::*;
2507
2508    #[test]
2509    fn optimization_boundary_keeps_non_forwarding_module() {
2510        let abi = sample_abi_module();
2511        let optimized = optimize_cps_repr_abi_module(&abi);
2512
2513        assert_eq!(optimized.module, abi);
2514        assert_eq!(optimized.profile.roots, 1);
2515        assert_eq!(optimized.profile.continuations, 1);
2516        assert_eq!(optimized.profile.optimized_continuations, 1);
2517        assert_eq!(optimized.profile.statements, 1);
2518        assert_eq!(optimized.profile.optimized_statements, 1);
2519        assert_eq!(optimized.profile.passes_run, 17);
2520        assert_eq!(optimized.profile.forwarded_continuation_calls, 0);
2521        assert_eq!(optimized.profile.returned_continuation_calls, 0);
2522        assert_eq!(optimized.profile.folded_constant_branches, 0);
2523        assert_eq!(optimized.profile.rewritten_pure_effectful_calls, 0);
2524        assert_eq!(optimized.profile.reified_primitive_calls, 0);
2525        assert_eq!(optimized.profile.reified_partial_closure_calls, 0);
2526        assert_eq!(optimized.profile.reified_known_closure_parameter_calls, 0);
2527        assert_eq!(optimized.profile.removed_unused_continuation_params, 0);
2528        assert_eq!(optimized.profile.folded_structural_projections, 0);
2529        assert_eq!(optimized.profile.inlined_pure_direct_calls, 0);
2530        assert_eq!(optimized.profile.inlined_continuation_calls, 0);
2531        assert_eq!(optimized.profile.removed_unreachable_continuations, 0);
2532        assert_eq!(optimized.profile.removed_dead_pure_statements, 0);
2533        assert_eq!(optimized.profile.direct_style_islands, 1);
2534        assert_eq!(optimized.profile.direct_style_continuations, 1);
2535        assert!(!optimized.profile.changed);
2536    }
2537
2538    #[test]
2539    fn rewrites_empty_continue_forwarder_calls() {
2540        let abi = lower_cps_repr_abi_module(&lower_cps_repr_module(&CpsModule {
2541            functions: Vec::new(),
2542            roots: vec![CpsFunction {
2543                name: "root".to_string(),
2544                params: Vec::new(),
2545                entry: CpsContinuationId(0),
2546                handlers: Vec::new(),
2547                continuations: vec![
2548                    crate::cps_ir::CpsContinuation {
2549                        id: CpsContinuationId(0),
2550                        params: Vec::new(),
2551                        captures: Vec::new(),
2552                        shot_kind: CpsShotKind::OneShot,
2553                        stmts: vec![CpsStmt::Literal {
2554                            dest: CpsValueId(0),
2555                            literal: CpsLiteral::Int("42".to_string()),
2556                        }],
2557                        terminator: CpsTerminator::Continue {
2558                            target: CpsContinuationId(1),
2559                            args: vec![CpsValueId(0)],
2560                        },
2561                    },
2562                    crate::cps_ir::CpsContinuation {
2563                        id: CpsContinuationId(1),
2564                        params: vec![CpsValueId(1)],
2565                        captures: Vec::new(),
2566                        shot_kind: CpsShotKind::OneShot,
2567                        stmts: Vec::new(),
2568                        terminator: CpsTerminator::Continue {
2569                            target: CpsContinuationId(2),
2570                            args: vec![CpsValueId(1)],
2571                        },
2572                    },
2573                    crate::cps_ir::CpsContinuation {
2574                        id: CpsContinuationId(2),
2575                        params: vec![CpsValueId(2)],
2576                        captures: Vec::new(),
2577                        shot_kind: CpsShotKind::OneShot,
2578                        stmts: Vec::new(),
2579                        terminator: CpsTerminator::Return(CpsValueId(2)),
2580                    },
2581                ],
2582            }],
2583        }));
2584
2585        let optimized = optimize_cps_repr_abi_module(&abi);
2586        let entry = &optimized.module.roots[0].continuations[0];
2587
2588        assert_eq!(entry.terminator, CpsTerminator::Return(CpsValueId(0)));
2589        assert_eq!(optimized.profile.forwarded_continuation_calls, 1);
2590        assert_eq!(optimized.profile.returned_continuation_calls, 2);
2591        assert_eq!(optimized.profile.reified_primitive_calls, 0);
2592        assert_eq!(optimized.profile.reified_partial_closure_calls, 0);
2593        assert_eq!(optimized.profile.inlined_pure_direct_calls, 0);
2594        assert_eq!(optimized.profile.inlined_continuation_calls, 0);
2595        assert_eq!(optimized.profile.removed_unreachable_continuations, 2);
2596        assert_eq!(optimized.profile.removed_dead_pure_statements, 0);
2597        assert_eq!(optimized.profile.direct_style_islands, 1);
2598        assert_eq!(optimized.profile.direct_style_continuations, 1);
2599        assert!(optimized.profile.changed);
2600    }
2601
2602    #[test]
2603    fn rewrites_empty_returning_continuation_calls() {
2604        let abi = lower_cps_repr_abi_module(&lower_cps_repr_module(&CpsModule {
2605            functions: Vec::new(),
2606            roots: vec![CpsFunction {
2607                name: "root".to_string(),
2608                params: Vec::new(),
2609                entry: CpsContinuationId(0),
2610                handlers: Vec::new(),
2611                continuations: vec![
2612                    crate::cps_ir::CpsContinuation {
2613                        id: CpsContinuationId(0),
2614                        params: Vec::new(),
2615                        captures: Vec::new(),
2616                        shot_kind: CpsShotKind::OneShot,
2617                        stmts: vec![CpsStmt::Literal {
2618                            dest: CpsValueId(0),
2619                            literal: CpsLiteral::Int("42".to_string()),
2620                        }],
2621                        terminator: CpsTerminator::Continue {
2622                            target: CpsContinuationId(1),
2623                            args: vec![CpsValueId(0)],
2624                        },
2625                    },
2626                    crate::cps_ir::CpsContinuation {
2627                        id: CpsContinuationId(1),
2628                        params: vec![CpsValueId(1)],
2629                        captures: Vec::new(),
2630                        shot_kind: CpsShotKind::OneShot,
2631                        stmts: Vec::new(),
2632                        terminator: CpsTerminator::Return(CpsValueId(1)),
2633                    },
2634                ],
2635            }],
2636        }));
2637
2638        let optimized = optimize_cps_repr_abi_module(&abi);
2639        let entry = &optimized.module.roots[0].continuations[0];
2640
2641        assert_eq!(entry.terminator, CpsTerminator::Return(CpsValueId(0)));
2642        assert_eq!(optimized.profile.returned_continuation_calls, 1);
2643        assert_eq!(optimized.profile.reified_primitive_calls, 0);
2644        assert_eq!(optimized.profile.reified_partial_closure_calls, 0);
2645        assert_eq!(optimized.profile.inlined_pure_direct_calls, 0);
2646        assert_eq!(optimized.profile.inlined_continuation_calls, 0);
2647        assert_eq!(optimized.profile.removed_unreachable_continuations, 1);
2648        assert_eq!(optimized.profile.removed_dead_pure_statements, 0);
2649        assert_eq!(optimized.profile.direct_style_islands, 1);
2650        assert_eq!(optimized.profile.direct_style_continuations, 1);
2651        assert!(optimized.profile.changed);
2652    }
2653
2654    #[test]
2655    fn inlines_single_use_one_shot_continuations() {
2656        let abi = lower_cps_repr_abi_module(&lower_cps_repr_module(&CpsModule {
2657            functions: Vec::new(),
2658            roots: vec![CpsFunction {
2659                name: "root".to_string(),
2660                params: Vec::new(),
2661                entry: CpsContinuationId(0),
2662                handlers: Vec::new(),
2663                continuations: vec![
2664                    crate::cps_ir::CpsContinuation {
2665                        id: CpsContinuationId(0),
2666                        params: Vec::new(),
2667                        captures: Vec::new(),
2668                        shot_kind: CpsShotKind::OneShot,
2669                        stmts: vec![CpsStmt::Literal {
2670                            dest: CpsValueId(0),
2671                            literal: CpsLiteral::Int("41".to_string()),
2672                        }],
2673                        terminator: CpsTerminator::Continue {
2674                            target: CpsContinuationId(1),
2675                            args: vec![CpsValueId(0)],
2676                        },
2677                    },
2678                    crate::cps_ir::CpsContinuation {
2679                        id: CpsContinuationId(1),
2680                        params: vec![CpsValueId(1)],
2681                        captures: Vec::new(),
2682                        shot_kind: CpsShotKind::OneShot,
2683                        stmts: vec![
2684                            CpsStmt::Literal {
2685                                dest: CpsValueId(2),
2686                                literal: CpsLiteral::Int("1".to_string()),
2687                            },
2688                            CpsStmt::Primitive {
2689                                dest: CpsValueId(3),
2690                                op: yulang_typed_ir::PrimitiveOp::IntAdd,
2691                                args: vec![CpsValueId(1), CpsValueId(2)],
2692                            },
2693                        ],
2694                        terminator: CpsTerminator::Return(CpsValueId(3)),
2695                    },
2696                ],
2697            }],
2698        }));
2699
2700        let optimized = optimize_cps_repr_abi_module(&abi);
2701        let root = &optimized.module.roots[0];
2702        let entry = &root.continuations[0];
2703
2704        assert_eq!(root.continuations.len(), 1);
2705        assert_eq!(entry.stmts.len(), 3);
2706        assert_eq!(
2707            entry.stmts[2],
2708            CpsStmt::Primitive {
2709                dest: CpsValueId(3),
2710                op: yulang_typed_ir::PrimitiveOp::IntAdd,
2711                args: vec![CpsValueId(0), CpsValueId(2)],
2712            }
2713        );
2714        assert_eq!(entry.terminator, CpsTerminator::Return(CpsValueId(3)));
2715        assert_eq!(optimized.profile.inlined_continuation_calls, 1);
2716        assert_eq!(optimized.profile.removed_unreachable_continuations, 1);
2717        assert_eq!(optimized.profile.direct_style_islands, 1);
2718        assert_eq!(optimized.profile.direct_style_continuations, 1);
2719    }
2720
2721    #[test]
2722    fn reifies_direct_calls_to_primitive_wrappers() {
2723        let abi = lower_cps_repr_abi_module(&lower_cps_repr_module(&CpsModule {
2724            functions: vec![CpsFunction {
2725                name: "add".to_string(),
2726                params: vec![CpsValueId(0), CpsValueId(1)],
2727                entry: CpsContinuationId(0),
2728                handlers: Vec::new(),
2729                continuations: vec![crate::cps_ir::CpsContinuation {
2730                    id: CpsContinuationId(0),
2731                    params: vec![CpsValueId(0), CpsValueId(1)],
2732                    captures: Vec::new(),
2733                    shot_kind: CpsShotKind::MultiShot,
2734                    stmts: vec![CpsStmt::Primitive {
2735                        dest: CpsValueId(2),
2736                        op: typed_ir::PrimitiveOp::IntAdd,
2737                        args: vec![CpsValueId(0), CpsValueId(1)],
2738                    }],
2739                    terminator: CpsTerminator::Return(CpsValueId(2)),
2740                }],
2741            }],
2742            roots: vec![CpsFunction {
2743                name: "root".to_string(),
2744                params: Vec::new(),
2745                entry: CpsContinuationId(0),
2746                handlers: Vec::new(),
2747                continuations: vec![crate::cps_ir::CpsContinuation {
2748                    id: CpsContinuationId(0),
2749                    params: Vec::new(),
2750                    captures: Vec::new(),
2751                    shot_kind: CpsShotKind::OneShot,
2752                    stmts: vec![
2753                        CpsStmt::Literal {
2754                            dest: CpsValueId(0),
2755                            literal: CpsLiteral::Int("1".to_string()),
2756                        },
2757                        CpsStmt::Literal {
2758                            dest: CpsValueId(1),
2759                            literal: CpsLiteral::Int("2".to_string()),
2760                        },
2761                        CpsStmt::DirectCall {
2762                            dest: CpsValueId(2),
2763                            target: "add".to_string(),
2764                            args: vec![CpsValueId(0), CpsValueId(1)],
2765                        },
2766                    ],
2767                    terminator: CpsTerminator::Return(CpsValueId(2)),
2768                }],
2769            }],
2770        }));
2771
2772        let optimized = optimize_cps_repr_abi_module(&abi);
2773        let entry = &optimized.module.roots[0].continuations[0];
2774
2775        assert_eq!(
2776            entry.stmts[2],
2777            CpsStmt::Primitive {
2778                dest: CpsValueId(2),
2779                op: typed_ir::PrimitiveOp::IntAdd,
2780                args: vec![CpsValueId(0), CpsValueId(1)],
2781            }
2782        );
2783        assert_eq!(optimized.profile.reified_primitive_calls, 1);
2784        assert_eq!(optimized.profile.direct_style_islands, 2);
2785        assert_eq!(optimized.profile.direct_style_continuations, 2);
2786    }
2787
2788    #[test]
2789    fn inlines_small_pure_direct_calls() {
2790        let abi = lower_cps_repr_abi_module(&lower_cps_repr_module(&CpsModule {
2791            functions: vec![CpsFunction {
2792                name: "plus_one".to_string(),
2793                params: vec![CpsValueId(0)],
2794                entry: CpsContinuationId(0),
2795                handlers: Vec::new(),
2796                continuations: vec![crate::cps_ir::CpsContinuation {
2797                    id: CpsContinuationId(0),
2798                    params: vec![CpsValueId(0)],
2799                    captures: Vec::new(),
2800                    shot_kind: CpsShotKind::OneShot,
2801                    stmts: vec![
2802                        CpsStmt::Literal {
2803                            dest: CpsValueId(1),
2804                            literal: CpsLiteral::Int("1".to_string()),
2805                        },
2806                        CpsStmt::Primitive {
2807                            dest: CpsValueId(2),
2808                            op: typed_ir::PrimitiveOp::IntAdd,
2809                            args: vec![CpsValueId(0), CpsValueId(1)],
2810                        },
2811                    ],
2812                    terminator: CpsTerminator::Return(CpsValueId(2)),
2813                }],
2814            }],
2815            roots: vec![CpsFunction {
2816                name: "root".to_string(),
2817                params: Vec::new(),
2818                entry: CpsContinuationId(0),
2819                handlers: Vec::new(),
2820                continuations: vec![crate::cps_ir::CpsContinuation {
2821                    id: CpsContinuationId(0),
2822                    params: Vec::new(),
2823                    captures: Vec::new(),
2824                    shot_kind: CpsShotKind::OneShot,
2825                    stmts: vec![
2826                        CpsStmt::Literal {
2827                            dest: CpsValueId(0),
2828                            literal: CpsLiteral::Int("41".to_string()),
2829                        },
2830                        CpsStmt::DirectCall {
2831                            dest: CpsValueId(1),
2832                            target: "plus_one".to_string(),
2833                            args: vec![CpsValueId(0)],
2834                        },
2835                    ],
2836                    terminator: CpsTerminator::Return(CpsValueId(1)),
2837                }],
2838            }],
2839        }));
2840
2841        let optimized = optimize_cps_repr_abi_module(&abi);
2842        let entry = &optimized.module.roots[0].continuations[0];
2843
2844        assert_eq!(entry.stmts.len(), 3);
2845        assert_eq!(
2846            entry.stmts[1],
2847            CpsStmt::Literal {
2848                dest: CpsValueId(2),
2849                literal: CpsLiteral::Int("1".to_string()),
2850            }
2851        );
2852        assert_eq!(
2853            entry.stmts[2],
2854            CpsStmt::Primitive {
2855                dest: CpsValueId(1),
2856                op: typed_ir::PrimitiveOp::IntAdd,
2857                args: vec![CpsValueId(0), CpsValueId(2)],
2858            }
2859        );
2860        assert_eq!(optimized.profile.inlined_pure_direct_calls, 1);
2861        assert_eq!(optimized.profile.removed_dead_pure_statements, 0);
2862    }
2863
2864    #[test]
2865    fn inlines_small_structural_pure_direct_calls() {
2866        let abi = lower_cps_repr_abi_module(&lower_cps_repr_module(&CpsModule {
2867            functions: vec![CpsFunction {
2868                name: "pair".to_string(),
2869                params: vec![CpsValueId(0), CpsValueId(1)],
2870                entry: CpsContinuationId(0),
2871                handlers: Vec::new(),
2872                continuations: vec![crate::cps_ir::CpsContinuation {
2873                    id: CpsContinuationId(0),
2874                    params: vec![CpsValueId(0), CpsValueId(1)],
2875                    captures: Vec::new(),
2876                    shot_kind: CpsShotKind::OneShot,
2877                    stmts: vec![CpsStmt::Tuple {
2878                        dest: CpsValueId(2),
2879                        items: vec![CpsValueId(0), CpsValueId(1)],
2880                    }],
2881                    terminator: CpsTerminator::Return(CpsValueId(2)),
2882                }],
2883            }],
2884            roots: vec![CpsFunction {
2885                name: "root".to_string(),
2886                params: Vec::new(),
2887                entry: CpsContinuationId(0),
2888                handlers: Vec::new(),
2889                continuations: vec![crate::cps_ir::CpsContinuation {
2890                    id: CpsContinuationId(0),
2891                    params: Vec::new(),
2892                    captures: Vec::new(),
2893                    shot_kind: CpsShotKind::OneShot,
2894                    stmts: vec![
2895                        CpsStmt::Literal {
2896                            dest: CpsValueId(0),
2897                            literal: CpsLiteral::Int("1".to_string()),
2898                        },
2899                        CpsStmt::Literal {
2900                            dest: CpsValueId(1),
2901                            literal: CpsLiteral::Int("2".to_string()),
2902                        },
2903                        CpsStmt::DirectCall {
2904                            dest: CpsValueId(2),
2905                            target: "pair".to_string(),
2906                            args: vec![CpsValueId(0), CpsValueId(1)],
2907                        },
2908                    ],
2909                    terminator: CpsTerminator::Return(CpsValueId(2)),
2910                }],
2911            }],
2912        }));
2913
2914        let optimized = optimize_cps_repr_abi_module(&abi);
2915        let entry = &optimized.module.roots[0].continuations[0];
2916
2917        assert_eq!(
2918            entry.stmts[2],
2919            CpsStmt::Tuple {
2920                dest: CpsValueId(2),
2921                items: vec![CpsValueId(0), CpsValueId(1)],
2922            }
2923        );
2924        assert_eq!(optimized.profile.inlined_pure_direct_calls, 1);
2925    }
2926
2927    #[test]
2928    fn rewrites_effectful_call_to_pure_callee() {
2929        let abi = lower_cps_repr_abi_module(&lower_cps_repr_module(&CpsModule {
2930            functions: vec![CpsFunction {
2931                name: "plus_one".to_string(),
2932                params: vec![CpsValueId(0)],
2933                entry: CpsContinuationId(0),
2934                handlers: Vec::new(),
2935                continuations: vec![
2936                    crate::cps_ir::CpsContinuation {
2937                        id: CpsContinuationId(0),
2938                        params: vec![CpsValueId(0)],
2939                        captures: Vec::new(),
2940                        shot_kind: CpsShotKind::OneShot,
2941                        stmts: vec![
2942                            CpsStmt::Literal {
2943                                dest: CpsValueId(1),
2944                                literal: CpsLiteral::Int("1".to_string()),
2945                            },
2946                            CpsStmt::Primitive {
2947                                dest: CpsValueId(2),
2948                                op: typed_ir::PrimitiveOp::IntAdd,
2949                                args: vec![CpsValueId(0), CpsValueId(1)],
2950                            },
2951                        ],
2952                        terminator: CpsTerminator::Continue {
2953                            target: CpsContinuationId(1),
2954                            args: vec![CpsValueId(2)],
2955                        },
2956                    },
2957                    crate::cps_ir::CpsContinuation {
2958                        id: CpsContinuationId(1),
2959                        params: vec![CpsValueId(3)],
2960                        captures: Vec::new(),
2961                        shot_kind: CpsShotKind::OneShot,
2962                        stmts: Vec::new(),
2963                        terminator: CpsTerminator::Return(CpsValueId(3)),
2964                    },
2965                ],
2966            }],
2967            roots: vec![CpsFunction {
2968                name: "root".to_string(),
2969                params: Vec::new(),
2970                entry: CpsContinuationId(0),
2971                handlers: Vec::new(),
2972                continuations: vec![
2973                    crate::cps_ir::CpsContinuation {
2974                        id: CpsContinuationId(0),
2975                        params: Vec::new(),
2976                        captures: Vec::new(),
2977                        shot_kind: CpsShotKind::OneShot,
2978                        stmts: vec![CpsStmt::Literal {
2979                            dest: CpsValueId(0),
2980                            literal: CpsLiteral::Int("41".to_string()),
2981                        }],
2982                        terminator: CpsTerminator::EffectfulCall {
2983                            target: "plus_one".to_string(),
2984                            args: vec![CpsValueId(0)],
2985                            resume: CpsContinuationId(1),
2986                        },
2987                    },
2988                    crate::cps_ir::CpsContinuation {
2989                        id: CpsContinuationId(1),
2990                        params: vec![CpsValueId(1)],
2991                        captures: Vec::new(),
2992                        shot_kind: CpsShotKind::OneShot,
2993                        stmts: Vec::new(),
2994                        terminator: CpsTerminator::Return(CpsValueId(1)),
2995                    },
2996                ],
2997            }],
2998        }));
2999
3000        let optimized = optimize_cps_repr_abi_module(&abi);
3001        let entry = &optimized.module.roots[0].continuations[0];
3002
3003        assert_eq!(
3004            entry.stmts[1],
3005            CpsStmt::Literal {
3006                dest: CpsValueId(3),
3007                literal: CpsLiteral::Int("1".to_string()),
3008            }
3009        );
3010        assert_eq!(
3011            entry.stmts[2],
3012            CpsStmt::Primitive {
3013                dest: CpsValueId(2),
3014                op: typed_ir::PrimitiveOp::IntAdd,
3015                args: vec![CpsValueId(0), CpsValueId(3)],
3016            }
3017        );
3018        assert_eq!(entry.terminator, CpsTerminator::Return(CpsValueId(2)));
3019        assert_eq!(optimized.profile.rewritten_pure_effectful_calls, 1);
3020        assert_eq!(optimized.profile.inlined_pure_direct_calls, 1);
3021        assert_eq!(optimized.profile.returned_continuation_calls, 1);
3022    }
3023
3024    #[test]
3025    fn reifies_local_partial_closure_apply_to_direct_call() {
3026        let abi = lower_cps_repr_abi_module(&lower_cps_repr_module(&CpsModule {
3027            functions: vec![CpsFunction {
3028                name: "add".to_string(),
3029                params: vec![CpsValueId(0), CpsValueId(1)],
3030                entry: CpsContinuationId(0),
3031                handlers: Vec::new(),
3032                continuations: vec![crate::cps_ir::CpsContinuation {
3033                    id: CpsContinuationId(0),
3034                    params: vec![CpsValueId(0), CpsValueId(1)],
3035                    captures: Vec::new(),
3036                    shot_kind: CpsShotKind::MultiShot,
3037                    stmts: vec![CpsStmt::Primitive {
3038                        dest: CpsValueId(2),
3039                        op: typed_ir::PrimitiveOp::IntAdd,
3040                        args: vec![CpsValueId(0), CpsValueId(1)],
3041                    }],
3042                    terminator: CpsTerminator::Return(CpsValueId(2)),
3043                }],
3044            }],
3045            roots: vec![CpsFunction {
3046                name: "root".to_string(),
3047                params: Vec::new(),
3048                entry: CpsContinuationId(0),
3049                handlers: Vec::new(),
3050                continuations: vec![
3051                    crate::cps_ir::CpsContinuation {
3052                        id: CpsContinuationId(0),
3053                        params: Vec::new(),
3054                        captures: Vec::new(),
3055                        shot_kind: CpsShotKind::OneShot,
3056                        stmts: vec![
3057                            CpsStmt::Literal {
3058                                dest: CpsValueId(0),
3059                                literal: CpsLiteral::Int("40".to_string()),
3060                            },
3061                            CpsStmt::MakeClosure {
3062                                dest: CpsValueId(1),
3063                                entry: CpsContinuationId(1),
3064                            },
3065                            CpsStmt::Literal {
3066                                dest: CpsValueId(2),
3067                                literal: CpsLiteral::Int("2".to_string()),
3068                            },
3069                            CpsStmt::ApplyClosure {
3070                                dest: CpsValueId(3),
3071                                closure: CpsValueId(1),
3072                                arg: CpsValueId(2),
3073                            },
3074                        ],
3075                        terminator: CpsTerminator::Return(CpsValueId(3)),
3076                    },
3077                    crate::cps_ir::CpsContinuation {
3078                        id: CpsContinuationId(1),
3079                        params: vec![CpsValueId(4)],
3080                        captures: vec![CpsValueId(0)],
3081                        shot_kind: CpsShotKind::OneShot,
3082                        stmts: vec![CpsStmt::DirectCall {
3083                            dest: CpsValueId(5),
3084                            target: "add".to_string(),
3085                            args: vec![CpsValueId(0), CpsValueId(4)],
3086                        }],
3087                        terminator: CpsTerminator::Return(CpsValueId(5)),
3088                    },
3089                ],
3090            }],
3091        }));
3092
3093        let optimized = optimize_cps_repr_abi_module(&abi);
3094        let entry = &optimized.module.roots[0].continuations[0];
3095
3096        assert_eq!(entry.stmts.len(), 3);
3097        assert_eq!(
3098            entry.stmts[2],
3099            CpsStmt::Primitive {
3100                dest: CpsValueId(3),
3101                op: typed_ir::PrimitiveOp::IntAdd,
3102                args: vec![CpsValueId(0), CpsValueId(2)],
3103            }
3104        );
3105        assert_eq!(optimized.profile.reified_partial_closure_calls, 1);
3106        assert_eq!(optimized.profile.removed_unreachable_continuations, 1);
3107        assert_eq!(optimized.profile.removed_dead_pure_statements, 1);
3108        assert_eq!(optimized.profile.direct_style_islands, 2);
3109        assert_eq!(optimized.profile.direct_style_continuations, 2);
3110    }
3111
3112    #[test]
3113    fn reifies_partial_closure_apply_after_inline() {
3114        let abi = lower_cps_repr_abi_module(&lower_cps_repr_module(&CpsModule {
3115            functions: vec![CpsFunction {
3116                name: "add".to_string(),
3117                params: vec![CpsValueId(0), CpsValueId(1)],
3118                entry: CpsContinuationId(0),
3119                handlers: Vec::new(),
3120                continuations: vec![crate::cps_ir::CpsContinuation {
3121                    id: CpsContinuationId(0),
3122                    params: vec![CpsValueId(0), CpsValueId(1)],
3123                    captures: Vec::new(),
3124                    shot_kind: CpsShotKind::MultiShot,
3125                    stmts: vec![CpsStmt::Primitive {
3126                        dest: CpsValueId(2),
3127                        op: typed_ir::PrimitiveOp::IntAdd,
3128                        args: vec![CpsValueId(0), CpsValueId(1)],
3129                    }],
3130                    terminator: CpsTerminator::Return(CpsValueId(2)),
3131                }],
3132            }],
3133            roots: vec![CpsFunction {
3134                name: "root".to_string(),
3135                params: Vec::new(),
3136                entry: CpsContinuationId(0),
3137                handlers: Vec::new(),
3138                continuations: vec![
3139                    crate::cps_ir::CpsContinuation {
3140                        id: CpsContinuationId(0),
3141                        params: Vec::new(),
3142                        captures: Vec::new(),
3143                        shot_kind: CpsShotKind::OneShot,
3144                        stmts: vec![
3145                            CpsStmt::Literal {
3146                                dest: CpsValueId(0),
3147                                literal: CpsLiteral::Int("40".to_string()),
3148                            },
3149                            CpsStmt::MakeClosure {
3150                                dest: CpsValueId(1),
3151                                entry: CpsContinuationId(1),
3152                            },
3153                            CpsStmt::Literal {
3154                                dest: CpsValueId(2),
3155                                literal: CpsLiteral::Int("2".to_string()),
3156                            },
3157                        ],
3158                        terminator: CpsTerminator::Continue {
3159                            target: CpsContinuationId(2),
3160                            args: vec![CpsValueId(1), CpsValueId(2)],
3161                        },
3162                    },
3163                    crate::cps_ir::CpsContinuation {
3164                        id: CpsContinuationId(1),
3165                        params: vec![CpsValueId(4)],
3166                        captures: vec![CpsValueId(0)],
3167                        shot_kind: CpsShotKind::OneShot,
3168                        stmts: vec![CpsStmt::DirectCall {
3169                            dest: CpsValueId(5),
3170                            target: "add".to_string(),
3171                            args: vec![CpsValueId(0), CpsValueId(4)],
3172                        }],
3173                        terminator: CpsTerminator::Return(CpsValueId(5)),
3174                    },
3175                    crate::cps_ir::CpsContinuation {
3176                        id: CpsContinuationId(2),
3177                        params: vec![CpsValueId(6), CpsValueId(7)],
3178                        captures: Vec::new(),
3179                        shot_kind: CpsShotKind::OneShot,
3180                        stmts: vec![CpsStmt::ApplyClosure {
3181                            dest: CpsValueId(8),
3182                            closure: CpsValueId(6),
3183                            arg: CpsValueId(7),
3184                        }],
3185                        terminator: CpsTerminator::Return(CpsValueId(8)),
3186                    },
3187                ],
3188            }],
3189        }));
3190
3191        let optimized = optimize_cps_repr_abi_module(&abi);
3192        let entry = &optimized.module.roots[0].continuations[0];
3193
3194        assert_eq!(entry.stmts.len(), 3);
3195        assert_eq!(
3196            entry.stmts[2],
3197            CpsStmt::Primitive {
3198                dest: CpsValueId(8),
3199                op: typed_ir::PrimitiveOp::IntAdd,
3200                args: vec![CpsValueId(0), CpsValueId(2)],
3201            }
3202        );
3203        assert_eq!(entry.terminator, CpsTerminator::Return(CpsValueId(8)));
3204        assert_eq!(optimized.profile.inlined_continuation_calls, 1);
3205        assert_eq!(optimized.profile.reified_partial_closure_calls, 1);
3206        assert_eq!(optimized.profile.removed_unreachable_continuations, 2);
3207        assert_eq!(optimized.profile.removed_dead_pure_statements, 1);
3208        assert_eq!(optimized.profile.direct_style_islands, 2);
3209        assert_eq!(optimized.profile.direct_style_continuations, 2);
3210    }
3211
3212    #[test]
3213    fn reifies_uncaptured_closure_apply_through_continuation_parameter() {
3214        let abi = lower_cps_repr_abi_module(&lower_cps_repr_module(&CpsModule {
3215            functions: Vec::new(),
3216            roots: vec![CpsFunction {
3217                name: "root".to_string(),
3218                params: Vec::new(),
3219                entry: CpsContinuationId(0),
3220                handlers: Vec::new(),
3221                continuations: vec![
3222                    crate::cps_ir::CpsContinuation {
3223                        id: CpsContinuationId(0),
3224                        params: Vec::new(),
3225                        captures: Vec::new(),
3226                        shot_kind: CpsShotKind::OneShot,
3227                        stmts: vec![
3228                            CpsStmt::MakeClosure {
3229                                dest: CpsValueId(0),
3230                                entry: CpsContinuationId(1),
3231                            },
3232                            CpsStmt::Literal {
3233                                dest: CpsValueId(1),
3234                                literal: CpsLiteral::Int("7".to_string()),
3235                            },
3236                        ],
3237                        terminator: CpsTerminator::Continue {
3238                            target: CpsContinuationId(2),
3239                            args: vec![CpsValueId(0), CpsValueId(1)],
3240                        },
3241                    },
3242                    crate::cps_ir::CpsContinuation {
3243                        id: CpsContinuationId(1),
3244                        params: vec![CpsValueId(2)],
3245                        captures: Vec::new(),
3246                        shot_kind: CpsShotKind::OneShot,
3247                        stmts: vec![CpsStmt::Primitive {
3248                            dest: CpsValueId(3),
3249                            op: typed_ir::PrimitiveOp::IntToString,
3250                            args: vec![CpsValueId(2)],
3251                        }],
3252                        terminator: CpsTerminator::Return(CpsValueId(3)),
3253                    },
3254                    crate::cps_ir::CpsContinuation {
3255                        id: CpsContinuationId(2),
3256                        params: vec![CpsValueId(4), CpsValueId(5)],
3257                        captures: Vec::new(),
3258                        shot_kind: CpsShotKind::OneShot,
3259                        stmts: vec![CpsStmt::ApplyClosure {
3260                            dest: CpsValueId(6),
3261                            closure: CpsValueId(4),
3262                            arg: CpsValueId(5),
3263                        }],
3264                        terminator: CpsTerminator::Return(CpsValueId(6)),
3265                    },
3266                ],
3267            }],
3268        }));
3269
3270        let optimized = optimize_cps_repr_abi_module(&abi);
3271        let root = &optimized.module.roots[0];
3272        let entry = root
3273            .continuations
3274            .iter()
3275            .find(|continuation| continuation.id == CpsContinuationId(0))
3276            .unwrap();
3277
3278        assert!(root.continuations.iter().all(|continuation| {
3279            continuation
3280                .stmts
3281                .iter()
3282                .all(|stmt| !matches!(stmt, CpsStmt::ApplyClosure { .. }))
3283        }));
3284        assert!(entry.stmts.iter().any(|stmt| {
3285            matches!(
3286                stmt,
3287                CpsStmt::Primitive {
3288                    op: typed_ir::PrimitiveOp::IntToString,
3289                    args,
3290                    ..
3291                } if args == &vec![CpsValueId(1)]
3292            )
3293        }));
3294        assert_eq!(entry.terminator, CpsTerminator::Return(CpsValueId(6)));
3295        assert_eq!(optimized.profile.reified_partial_closure_calls, 0);
3296        assert_eq!(optimized.profile.reified_known_closure_parameter_calls, 1);
3297        assert_eq!(optimized.profile.inlined_continuation_calls, 1);
3298        assert_eq!(optimized.profile.removed_unreachable_continuations, 2);
3299        assert_eq!(optimized.profile.removed_dead_pure_statements, 1);
3300    }
3301
3302    #[test]
3303    fn reifies_captured_closure_apply_when_captures_are_continuation_parameters() {
3304        let abi = lower_cps_repr_abi_module(&lower_cps_repr_module(&CpsModule {
3305            functions: vec![CpsFunction {
3306                name: "add".to_string(),
3307                params: vec![CpsValueId(0), CpsValueId(1)],
3308                entry: CpsContinuationId(0),
3309                handlers: Vec::new(),
3310                continuations: vec![crate::cps_ir::CpsContinuation {
3311                    id: CpsContinuationId(0),
3312                    params: vec![CpsValueId(0), CpsValueId(1)],
3313                    captures: Vec::new(),
3314                    shot_kind: CpsShotKind::MultiShot,
3315                    stmts: vec![CpsStmt::Primitive {
3316                        dest: CpsValueId(2),
3317                        op: typed_ir::PrimitiveOp::IntAdd,
3318                        args: vec![CpsValueId(0), CpsValueId(1)],
3319                    }],
3320                    terminator: CpsTerminator::Return(CpsValueId(2)),
3321                }],
3322            }],
3323            roots: vec![CpsFunction {
3324                name: "root".to_string(),
3325                params: Vec::new(),
3326                entry: CpsContinuationId(0),
3327                handlers: Vec::new(),
3328                continuations: vec![
3329                    crate::cps_ir::CpsContinuation {
3330                        id: CpsContinuationId(0),
3331                        params: Vec::new(),
3332                        captures: Vec::new(),
3333                        shot_kind: CpsShotKind::OneShot,
3334                        stmts: vec![
3335                            CpsStmt::Literal {
3336                                dest: CpsValueId(0),
3337                                literal: CpsLiteral::Int("40".to_string()),
3338                            },
3339                            CpsStmt::MakeClosure {
3340                                dest: CpsValueId(1),
3341                                entry: CpsContinuationId(1),
3342                            },
3343                            CpsStmt::Literal {
3344                                dest: CpsValueId(2),
3345                                literal: CpsLiteral::Int("2".to_string()),
3346                            },
3347                        ],
3348                        terminator: CpsTerminator::Continue {
3349                            target: CpsContinuationId(2),
3350                            args: vec![CpsValueId(1), CpsValueId(0), CpsValueId(2)],
3351                        },
3352                    },
3353                    crate::cps_ir::CpsContinuation {
3354                        id: CpsContinuationId(1),
3355                        params: vec![CpsValueId(4)],
3356                        captures: vec![CpsValueId(0)],
3357                        shot_kind: CpsShotKind::OneShot,
3358                        stmts: vec![CpsStmt::DirectCall {
3359                            dest: CpsValueId(5),
3360                            target: "add".to_string(),
3361                            args: vec![CpsValueId(0), CpsValueId(4)],
3362                        }],
3363                        terminator: CpsTerminator::Return(CpsValueId(5)),
3364                    },
3365                    crate::cps_ir::CpsContinuation {
3366                        id: CpsContinuationId(2),
3367                        params: vec![CpsValueId(6), CpsValueId(7), CpsValueId(8)],
3368                        captures: Vec::new(),
3369                        shot_kind: CpsShotKind::OneShot,
3370                        stmts: vec![CpsStmt::ApplyClosure {
3371                            dest: CpsValueId(9),
3372                            closure: CpsValueId(6),
3373                            arg: CpsValueId(8),
3374                        }],
3375                        terminator: CpsTerminator::Return(CpsValueId(9)),
3376                    },
3377                ],
3378            }],
3379        }));
3380
3381        let optimized = optimize_cps_repr_abi_module(&abi);
3382        let root = &optimized.module.roots[0];
3383        let entry = root
3384            .continuations
3385            .iter()
3386            .find(|continuation| continuation.id == CpsContinuationId(0))
3387            .unwrap();
3388
3389        assert!(root.continuations.iter().all(|continuation| {
3390            continuation
3391                .stmts
3392                .iter()
3393                .all(|stmt| !matches!(stmt, CpsStmt::ApplyClosure { .. }))
3394        }));
3395        assert!(entry.stmts.iter().any(|stmt| {
3396            matches!(
3397                stmt,
3398                CpsStmt::Primitive {
3399                    op: typed_ir::PrimitiveOp::IntAdd,
3400                    args,
3401                    ..
3402                } if args == &vec![CpsValueId(0), CpsValueId(2)]
3403            )
3404        }));
3405        assert_eq!(entry.terminator, CpsTerminator::Return(CpsValueId(9)));
3406        assert_eq!(optimized.profile.reified_partial_closure_calls, 0);
3407        assert_eq!(optimized.profile.reified_known_closure_parameter_calls, 1);
3408        assert_eq!(optimized.profile.inlined_continuation_calls, 1);
3409        assert_eq!(optimized.profile.removed_unreachable_continuations, 2);
3410        assert_eq!(optimized.profile.removed_dead_pure_statements, 1);
3411    }
3412
3413    #[test]
3414    fn reifies_local_effectful_apply_to_known_primitive_closure() {
3415        let abi = lower_cps_repr_abi_module(&lower_cps_repr_module(&CpsModule {
3416            functions: Vec::new(),
3417            roots: vec![CpsFunction {
3418                name: "root".to_string(),
3419                params: Vec::new(),
3420                entry: CpsContinuationId(0),
3421                handlers: Vec::new(),
3422                continuations: vec![
3423                    crate::cps_ir::CpsContinuation {
3424                        id: CpsContinuationId(0),
3425                        params: Vec::new(),
3426                        captures: Vec::new(),
3427                        shot_kind: CpsShotKind::OneShot,
3428                        stmts: vec![
3429                            CpsStmt::MakeClosure {
3430                                dest: CpsValueId(0),
3431                                entry: CpsContinuationId(1),
3432                            },
3433                            CpsStmt::Literal {
3434                                dest: CpsValueId(1),
3435                                literal: CpsLiteral::Int("7".to_string()),
3436                            },
3437                        ],
3438                        terminator: CpsTerminator::EffectfulApply {
3439                            closure: CpsValueId(0),
3440                            arg: CpsValueId(1),
3441                            resume: CpsContinuationId(2),
3442                        },
3443                    },
3444                    crate::cps_ir::CpsContinuation {
3445                        id: CpsContinuationId(1),
3446                        params: vec![CpsValueId(2)],
3447                        captures: Vec::new(),
3448                        shot_kind: CpsShotKind::OneShot,
3449                        stmts: vec![CpsStmt::Primitive {
3450                            dest: CpsValueId(3),
3451                            op: typed_ir::PrimitiveOp::IntToString,
3452                            args: vec![CpsValueId(2)],
3453                        }],
3454                        terminator: CpsTerminator::Return(CpsValueId(3)),
3455                    },
3456                    crate::cps_ir::CpsContinuation {
3457                        id: CpsContinuationId(2),
3458                        params: vec![CpsValueId(4)],
3459                        captures: Vec::new(),
3460                        shot_kind: CpsShotKind::OneShot,
3461                        stmts: Vec::new(),
3462                        terminator: CpsTerminator::Return(CpsValueId(4)),
3463                    },
3464                ],
3465            }],
3466        }));
3467
3468        let optimized = optimize_cps_repr_abi_module(&abi);
3469        let root = &optimized.module.roots[0];
3470        let entry = root
3471            .continuations
3472            .iter()
3473            .find(|continuation| continuation.id == CpsContinuationId(0))
3474            .unwrap();
3475
3476        assert!(root.continuations.iter().all(|continuation| {
3477            !matches!(
3478                continuation.terminator,
3479                CpsTerminator::EffectfulApply { .. }
3480            )
3481        }));
3482        assert!(entry.stmts.iter().any(|stmt| {
3483            matches!(
3484                stmt,
3485                CpsStmt::Primitive {
3486                    op: typed_ir::PrimitiveOp::IntToString,
3487                    args,
3488                    ..
3489                } if args == &vec![CpsValueId(1)]
3490            )
3491        }));
3492        assert!(matches!(entry.terminator, CpsTerminator::Return(_)));
3493        assert_eq!(optimized.profile.reified_partial_closure_calls, 1);
3494        assert_eq!(optimized.profile.removed_unreachable_continuations, 2);
3495        assert_eq!(optimized.profile.removed_dead_pure_statements, 1);
3496    }
3497
3498    #[test]
3499    fn removes_dead_pure_value_statements() {
3500        let abi = lower_cps_repr_abi_module(&lower_cps_repr_module(&CpsModule {
3501            functions: Vec::new(),
3502            roots: vec![CpsFunction {
3503                name: "root".to_string(),
3504                params: Vec::new(),
3505                entry: CpsContinuationId(0),
3506                handlers: Vec::new(),
3507                continuations: vec![crate::cps_ir::CpsContinuation {
3508                    id: CpsContinuationId(0),
3509                    params: Vec::new(),
3510                    captures: Vec::new(),
3511                    shot_kind: CpsShotKind::OneShot,
3512                    stmts: vec![
3513                        CpsStmt::Literal {
3514                            dest: CpsValueId(0),
3515                            literal: CpsLiteral::Int("1".to_string()),
3516                        },
3517                        CpsStmt::Literal {
3518                            dest: CpsValueId(1),
3519                            literal: CpsLiteral::Int("2".to_string()),
3520                        },
3521                        CpsStmt::Tuple {
3522                            dest: CpsValueId(2),
3523                            items: vec![CpsValueId(0), CpsValueId(1)],
3524                        },
3525                    ],
3526                    terminator: CpsTerminator::Return(CpsValueId(0)),
3527                }],
3528            }],
3529        }));
3530
3531        let optimized = optimize_cps_repr_abi_module(&abi);
3532        let entry = &optimized.module.roots[0].continuations[0];
3533
3534        assert_eq!(
3535            entry.stmts,
3536            vec![CpsStmt::Literal {
3537                dest: CpsValueId(0),
3538                literal: CpsLiteral::Int("1".to_string()),
3539            }]
3540        );
3541        assert_eq!(optimized.profile.removed_dead_pure_statements, 2);
3542    }
3543
3544    #[test]
3545    fn removes_dead_total_primitives_and_structural_projections() {
3546        let abi = lower_cps_repr_abi_module(&lower_cps_repr_module(&CpsModule {
3547            functions: Vec::new(),
3548            roots: vec![CpsFunction {
3549                name: "root".to_string(),
3550                params: Vec::new(),
3551                entry: CpsContinuationId(0),
3552                handlers: Vec::new(),
3553                continuations: vec![crate::cps_ir::CpsContinuation {
3554                    id: CpsContinuationId(0),
3555                    params: Vec::new(),
3556                    captures: Vec::new(),
3557                    shot_kind: CpsShotKind::OneShot,
3558                    stmts: vec![
3559                        CpsStmt::Literal {
3560                            dest: CpsValueId(0),
3561                            literal: CpsLiteral::Int("1".to_string()),
3562                        },
3563                        CpsStmt::Literal {
3564                            dest: CpsValueId(1),
3565                            literal: CpsLiteral::Int("2".to_string()),
3566                        },
3567                        CpsStmt::Primitive {
3568                            dest: CpsValueId(2),
3569                            op: typed_ir::PrimitiveOp::IntAdd,
3570                            args: vec![CpsValueId(0), CpsValueId(1)],
3571                        },
3572                        CpsStmt::Tuple {
3573                            dest: CpsValueId(3),
3574                            items: vec![CpsValueId(0), CpsValueId(1)],
3575                        },
3576                        CpsStmt::TupleGet {
3577                            dest: CpsValueId(4),
3578                            tuple: CpsValueId(3),
3579                            index: 1,
3580                        },
3581                    ],
3582                    terminator: CpsTerminator::Return(CpsValueId(0)),
3583                }],
3584            }],
3585        }));
3586
3587        let optimized = optimize_cps_repr_abi_module(&abi);
3588        let entry = &optimized.module.roots[0].continuations[0];
3589
3590        assert_eq!(
3591            entry.stmts,
3592            vec![CpsStmt::Literal {
3593                dest: CpsValueId(0),
3594                literal: CpsLiteral::Int("1".to_string()),
3595            }]
3596        );
3597        assert_eq!(optimized.profile.folded_structural_projections, 1);
3598        assert_eq!(optimized.profile.removed_dead_pure_statements, 3);
3599    }
3600
3601    #[test]
3602    fn folds_tuple_get_from_local_tuple() {
3603        let abi = lower_cps_repr_abi_module(&lower_cps_repr_module(&CpsModule {
3604            functions: Vec::new(),
3605            roots: vec![CpsFunction {
3606                name: "root".to_string(),
3607                params: Vec::new(),
3608                entry: CpsContinuationId(0),
3609                handlers: Vec::new(),
3610                continuations: vec![crate::cps_ir::CpsContinuation {
3611                    id: CpsContinuationId(0),
3612                    params: Vec::new(),
3613                    captures: Vec::new(),
3614                    shot_kind: CpsShotKind::OneShot,
3615                    stmts: vec![
3616                        CpsStmt::Literal {
3617                            dest: CpsValueId(0),
3618                            literal: CpsLiteral::Int("1".to_string()),
3619                        },
3620                        CpsStmt::Literal {
3621                            dest: CpsValueId(1),
3622                            literal: CpsLiteral::Int("2".to_string()),
3623                        },
3624                        CpsStmt::Tuple {
3625                            dest: CpsValueId(2),
3626                            items: vec![CpsValueId(0), CpsValueId(1)],
3627                        },
3628                        CpsStmt::TupleGet {
3629                            dest: CpsValueId(3),
3630                            tuple: CpsValueId(2),
3631                            index: 1,
3632                        },
3633                    ],
3634                    terminator: CpsTerminator::Return(CpsValueId(3)),
3635                }],
3636            }],
3637        }));
3638
3639        let optimized = optimize_cps_repr_abi_module(&abi);
3640        let entry = &optimized.module.roots[0].continuations[0];
3641
3642        assert_eq!(
3643            entry.stmts,
3644            vec![CpsStmt::Literal {
3645                dest: CpsValueId(1),
3646                literal: CpsLiteral::Int("2".to_string()),
3647            }]
3648        );
3649        assert_eq!(entry.terminator, CpsTerminator::Return(CpsValueId(1)));
3650        assert_eq!(optimized.profile.folded_structural_projections, 1);
3651        assert_eq!(optimized.profile.removed_dead_pure_statements, 2);
3652    }
3653
3654    #[test]
3655    fn removes_unused_multi_use_continuation_parameters() {
3656        let abi = lower_cps_repr_abi_module(&lower_cps_repr_module(&CpsModule {
3657            functions: Vec::new(),
3658            roots: vec![CpsFunction {
3659                name: "root".to_string(),
3660                params: Vec::new(),
3661                entry: CpsContinuationId(0),
3662                handlers: Vec::new(),
3663                continuations: vec![
3664                    crate::cps_ir::CpsContinuation {
3665                        id: CpsContinuationId(0),
3666                        params: Vec::new(),
3667                        captures: Vec::new(),
3668                        shot_kind: CpsShotKind::OneShot,
3669                        stmts: vec![
3670                            CpsStmt::Literal {
3671                                dest: CpsValueId(0),
3672                                literal: CpsLiteral::Int("1".to_string()),
3673                            },
3674                            CpsStmt::Literal {
3675                                dest: CpsValueId(8),
3676                                literal: CpsLiteral::Bool(false),
3677                            },
3678                            CpsStmt::Primitive {
3679                                dest: CpsValueId(9),
3680                                op: typed_ir::PrimitiveOp::BoolNot,
3681                                args: vec![CpsValueId(8)],
3682                            },
3683                        ],
3684                        terminator: CpsTerminator::Branch {
3685                            cond: CpsValueId(9),
3686                            then_cont: CpsContinuationId(1),
3687                            else_cont: CpsContinuationId(2),
3688                        },
3689                    },
3690                    crate::cps_ir::CpsContinuation {
3691                        id: CpsContinuationId(1),
3692                        params: Vec::new(),
3693                        captures: Vec::new(),
3694                        shot_kind: CpsShotKind::OneShot,
3695                        stmts: vec![CpsStmt::Literal {
3696                            dest: CpsValueId(2),
3697                            literal: CpsLiteral::Int("2".to_string()),
3698                        }],
3699                        terminator: CpsTerminator::Continue {
3700                            target: CpsContinuationId(3),
3701                            args: vec![CpsValueId(0), CpsValueId(2)],
3702                        },
3703                    },
3704                    crate::cps_ir::CpsContinuation {
3705                        id: CpsContinuationId(2),
3706                        params: Vec::new(),
3707                        captures: Vec::new(),
3708                        shot_kind: CpsShotKind::OneShot,
3709                        stmts: vec![CpsStmt::Literal {
3710                            dest: CpsValueId(3),
3711                            literal: CpsLiteral::Int("3".to_string()),
3712                        }],
3713                        terminator: CpsTerminator::Continue {
3714                            target: CpsContinuationId(3),
3715                            args: vec![CpsValueId(0), CpsValueId(3)],
3716                        },
3717                    },
3718                    crate::cps_ir::CpsContinuation {
3719                        id: CpsContinuationId(3),
3720                        params: vec![CpsValueId(4), CpsValueId(5)],
3721                        captures: Vec::new(),
3722                        shot_kind: CpsShotKind::OneShot,
3723                        stmts: vec![
3724                            CpsStmt::Literal {
3725                                dest: CpsValueId(6),
3726                                literal: CpsLiteral::Int("0".to_string()),
3727                            },
3728                            CpsStmt::Primitive {
3729                                dest: CpsValueId(7),
3730                                op: typed_ir::PrimitiveOp::IntAdd,
3731                                args: vec![CpsValueId(5), CpsValueId(6)],
3732                            },
3733                        ],
3734                        terminator: CpsTerminator::Return(CpsValueId(7)),
3735                    },
3736                ],
3737            }],
3738        }));
3739
3740        let optimized = optimize_cps_repr_abi_module(&abi);
3741        let root = &optimized.module.roots[0];
3742        let join = root
3743            .continuations
3744            .iter()
3745            .find(|continuation| continuation.id == CpsContinuationId(3))
3746            .unwrap();
3747
3748        assert_eq!(
3749            join.params
3750                .iter()
3751                .map(|param| param.value)
3752                .collect::<Vec<_>>(),
3753            vec![CpsValueId(5)]
3754        );
3755        for source in [CpsContinuationId(1), CpsContinuationId(2)] {
3756            let continuation = root
3757                .continuations
3758                .iter()
3759                .find(|continuation| continuation.id == source)
3760                .unwrap();
3761            assert!(matches!(
3762                &continuation.terminator,
3763                CpsTerminator::Continue { args, .. } if args.len() == 1
3764            ));
3765        }
3766        assert_eq!(optimized.profile.removed_unused_continuation_params, 1);
3767        assert_eq!(optimized.profile.removed_dead_pure_statements, 1);
3768    }
3769
3770    #[test]
3771    fn folds_constant_bool_branches_before_pruning() {
3772        let abi = lower_cps_repr_abi_module(&lower_cps_repr_module(&CpsModule {
3773            functions: Vec::new(),
3774            roots: vec![CpsFunction {
3775                name: "root".to_string(),
3776                params: Vec::new(),
3777                entry: CpsContinuationId(0),
3778                handlers: Vec::new(),
3779                continuations: vec![
3780                    crate::cps_ir::CpsContinuation {
3781                        id: CpsContinuationId(0),
3782                        params: Vec::new(),
3783                        captures: Vec::new(),
3784                        shot_kind: CpsShotKind::OneShot,
3785                        stmts: vec![CpsStmt::Literal {
3786                            dest: CpsValueId(0),
3787                            literal: CpsLiteral::Bool(true),
3788                        }],
3789                        terminator: CpsTerminator::Branch {
3790                            cond: CpsValueId(0),
3791                            then_cont: CpsContinuationId(1),
3792                            else_cont: CpsContinuationId(2),
3793                        },
3794                    },
3795                    crate::cps_ir::CpsContinuation {
3796                        id: CpsContinuationId(1),
3797                        params: Vec::new(),
3798                        captures: Vec::new(),
3799                        shot_kind: CpsShotKind::OneShot,
3800                        stmts: vec![CpsStmt::Literal {
3801                            dest: CpsValueId(1),
3802                            literal: CpsLiteral::Int("1".to_string()),
3803                        }],
3804                        terminator: CpsTerminator::Return(CpsValueId(1)),
3805                    },
3806                    crate::cps_ir::CpsContinuation {
3807                        id: CpsContinuationId(2),
3808                        params: Vec::new(),
3809                        captures: Vec::new(),
3810                        shot_kind: CpsShotKind::OneShot,
3811                        stmts: vec![CpsStmt::Literal {
3812                            dest: CpsValueId(2),
3813                            literal: CpsLiteral::Int("2".to_string()),
3814                        }],
3815                        terminator: CpsTerminator::Return(CpsValueId(2)),
3816                    },
3817                ],
3818            }],
3819        }));
3820
3821        let optimized = optimize_cps_repr_abi_module(&abi);
3822        let entry = &optimized.module.roots[0].continuations[0];
3823
3824        assert_eq!(
3825            entry.stmts,
3826            vec![CpsStmt::Literal {
3827                dest: CpsValueId(1),
3828                literal: CpsLiteral::Int("1".to_string()),
3829            }]
3830        );
3831        assert_eq!(entry.terminator, CpsTerminator::Return(CpsValueId(1)));
3832        assert_eq!(optimized.profile.folded_constant_branches, 1);
3833        assert_eq!(optimized.profile.inlined_continuation_calls, 1);
3834        assert_eq!(optimized.profile.removed_unreachable_continuations, 2);
3835        assert_eq!(optimized.profile.removed_dead_pure_statements, 1);
3836    }
3837
3838    #[test]
3839    fn keeps_handler_arm_entries_when_pruning_unreachable_continuations() {
3840        let effect = yulang_typed_ir::Path::from_name(yulang_typed_ir::Name("ask".to_string()));
3841        let abi = lower_cps_repr_abi_module(&lower_cps_repr_module(&CpsModule {
3842            functions: Vec::new(),
3843            roots: vec![CpsFunction {
3844                name: "root".to_string(),
3845                params: Vec::new(),
3846                entry: CpsContinuationId(0),
3847                handlers: vec![crate::cps_ir::CpsHandler {
3848                    id: crate::cps_ir::CpsHandlerId(0),
3849                    arms: vec![crate::cps_ir::CpsHandlerArm {
3850                        effect,
3851                        entry: CpsContinuationId(1),
3852                    }],
3853                }],
3854                continuations: vec![
3855                    crate::cps_ir::CpsContinuation {
3856                        id: CpsContinuationId(0),
3857                        params: Vec::new(),
3858                        captures: Vec::new(),
3859                        shot_kind: CpsShotKind::OneShot,
3860                        stmts: vec![CpsStmt::Literal {
3861                            dest: CpsValueId(0),
3862                            literal: CpsLiteral::Int("1".to_string()),
3863                        }],
3864                        terminator: CpsTerminator::Return(CpsValueId(0)),
3865                    },
3866                    crate::cps_ir::CpsContinuation {
3867                        id: CpsContinuationId(1),
3868                        params: vec![CpsValueId(1), CpsValueId(2)],
3869                        captures: Vec::new(),
3870                        shot_kind: CpsShotKind::MultiShot,
3871                        stmts: Vec::new(),
3872                        terminator: CpsTerminator::Return(CpsValueId(1)),
3873                    },
3874                    crate::cps_ir::CpsContinuation {
3875                        id: CpsContinuationId(2),
3876                        params: Vec::new(),
3877                        captures: Vec::new(),
3878                        shot_kind: CpsShotKind::OneShot,
3879                        stmts: Vec::new(),
3880                        terminator: CpsTerminator::Return(CpsValueId(0)),
3881                    },
3882                ],
3883            }],
3884        }));
3885
3886        let optimized = optimize_cps_repr_abi_module(&abi);
3887        let ids = optimized.module.roots[0]
3888            .continuations
3889            .iter()
3890            .map(|continuation| continuation.id)
3891            .collect::<Vec<_>>();
3892
3893        assert_eq!(ids, vec![CpsContinuationId(0), CpsContinuationId(1)]);
3894        assert_eq!(optimized.profile.removed_unreachable_continuations, 1);
3895    }
3896
3897    fn sample_abi_module() -> CpsReprAbiModule {
3898        lower_cps_repr_abi_module(&lower_cps_repr_module(&CpsModule {
3899            functions: Vec::new(),
3900            roots: vec![CpsFunction {
3901                name: "root".to_string(),
3902                params: Vec::new(),
3903                entry: CpsContinuationId(0),
3904                handlers: Vec::new(),
3905                continuations: vec![crate::cps_ir::CpsContinuation {
3906                    id: CpsContinuationId(0),
3907                    params: Vec::new(),
3908                    captures: Vec::new(),
3909                    shot_kind: CpsShotKind::OneShot,
3910                    stmts: vec![CpsStmt::Literal {
3911                        dest: CpsValueId(0),
3912                        literal: CpsLiteral::Int("42".to_string()),
3913                    }],
3914                    terminator: CpsTerminator::Return(CpsValueId(0)),
3915                }],
3916            }],
3917        }))
3918    }
3919}