Skip to main content

shape_vm/mir/
storage_planning.rs

//! Storage Planning Pass — decides the runtime storage class for each binding.
//!
//! After MIR lowering and borrow analysis, this pass examines each local slot
//! and assigns a `BindingStorageClass`:
//!
//! - `Direct`: The default, used whenever no rule below applies — e.g. for
//!   bindings that are never captured, aliased, or reassigned, and for
//!   immutable closure captures.
//! - `UniqueHeap`: For bindings captured by a closure and mutated (need Arc wrapper).
//! - `SharedCow`: For `var` bindings that are aliased AND mutated (copy-on-write),
//!   or for escaped mutable aliased bindings.
//! - `Reference`: For bindings that hold first-class references.
//! - `Deferred`: Only if analysis was incomplete (had fallbacks).
//!
//! The pass also computes `EscapeStatus` for each slot:
//! - `Local`: Stays within the declaring scope.
//! - `Captured`: Captured by a closure (and does not escape).
//! - `Escaped`: Flows to the return slot (escapes the function); this takes
//!   precedence over `Captured`.
//!
//! Escape status drives storage decisions (escaped + aliased + mutated → `SharedCow`)
//! and is consumed by the post-solve relaxation pass to determine whether
//! local containers can safely hold references.
//!
//! The pass runs once per function and produces a `StoragePlan` consumed by codegen.
23
24use std::collections::{HashMap, HashSet};
25
26use crate::mir::analysis::BorrowAnalysis;
27use crate::mir::types::*;
28use crate::type_tracking::{
29    Aliasability, BindingOwnershipClass, BindingSemantics, BindingStorageClass, EscapeStatus,
30    MutationCapability,
31};
32
33/// The computed storage plan for a single function.
34#[derive(Debug, Clone)]
35pub struct StoragePlan {
36    /// Maps each local slot to its decided storage class.
37    pub slot_classes: HashMap<SlotId, BindingStorageClass>,
38    /// Maps each local slot to its enriched binding semantics.
39    pub slot_semantics: HashMap<SlotId, BindingSemantics>,
40}
41
42/// Input bundle for the storage planner.
43pub struct StoragePlannerInput<'a> {
44    /// The MIR function to plan storage for.
45    pub mir: &'a MirFunction,
46    /// Borrow analysis results (includes liveness).
47    pub analysis: &'a BorrowAnalysis,
48    /// Per-slot ownership/storage semantics from the compiler's type tracker.
49    pub binding_semantics: &'a HashMap<u16, BindingSemantics>,
50    /// Slots captured by any closure in this function.
51    pub closure_captures: &'a HashSet<SlotId>,
52    /// Slots that are mutated inside a closure body.
53    pub mutable_captures: &'a HashSet<SlotId>,
54    /// Whether MIR lowering had fallbacks (incomplete analysis).
55    pub had_fallbacks: bool,
56}
57
58/// Scan MIR statements and terminators to find slots captured by closures.
59///
60/// Returns `(all_captures, mutable_captures)`:
61/// - `all_captures`: slots referenced in `ClosureCapture` statements
62/// - `mutable_captures`: subset of captured slots that are assigned more than
63///   once in the function (i.e., re-assigned after initial definition). A slot
64///   with only its initial definition assignment is not considered mutably captured.
65pub fn collect_closure_captures(mir: &MirFunction) -> (HashSet<SlotId>, HashSet<SlotId>) {
66    let mut all_captures = HashSet::new();
67    let mut assign_counts: HashMap<SlotId, u32> = HashMap::new();
68
69    for block in mir.iter_blocks() {
70        for stmt in &block.statements {
71            match &stmt.kind {
72                StatementKind::ClosureCapture { operands, .. } => {
73                    for op in operands {
74                        if let Some(slot) = operand_root_slot(op) {
75                            all_captures.insert(slot);
76                        }
77                    }
78                }
79                StatementKind::Assign(place, _) => {
80                    if let Place::Local(slot) = place {
81                        *assign_counts.entry(*slot).or_insert(0) += 1;
82                    }
83                }
84                _ => {}
85            }
86        }
87    }
88
89    // A slot is "mutably captured" if it is captured AND assigned more than once
90    // (meaning it has re-assignments beyond its initial definition).
91    let mutable_captures: HashSet<SlotId> = all_captures
92        .iter()
93        .filter(|slot| assign_counts.get(slot).copied().unwrap_or(0) > 1)
94        .copied()
95        .collect();
96
97    (all_captures, mutable_captures)
98}
99
100/// Extract the root SlotId from an operand, if it references a local.
101fn operand_root_slot(op: &Operand) -> Option<SlotId> {
102    match op {
103        Operand::Copy(place) | Operand::Move(place) | Operand::MoveExplicit(place) => {
104            Some(place.root_local())
105        }
106        Operand::Constant(_) => None,
107    }
108}
109
110/// Check whether a slot has any active loan (borrow) in the analysis.
111/// A slot with loans is holding or being borrowed as a reference.
112fn slot_has_active_loans(slot: SlotId, analysis: &BorrowAnalysis) -> bool {
113    for loan_info in analysis.loans.values() {
114        if loan_info.borrowed_place.root_local() == slot {
115            return true;
116        }
117    }
118    false
119}
120
121/// Check whether a slot is aliased — it appears as an operand in more than
122/// one `Assign` rvalue across the function, or it is captured.
123fn slot_is_aliased(slot: SlotId, mir: &MirFunction, closure_captures: &HashSet<SlotId>) -> bool {
124    if closure_captures.contains(&slot) {
125        return true;
126    }
127
128    let mut use_count = 0u32;
129    for block in mir.iter_blocks() {
130        for stmt in &block.statements {
131            if let StatementKind::Assign(_, rvalue) = &stmt.kind {
132                if rvalue_uses_slot(rvalue, slot) {
133                    use_count += 1;
134                    if use_count > 1 {
135                        return true;
136                    }
137                }
138            }
139        }
140        // Also check terminators for uses
141        if let TerminatorKind::Call { func, args, .. } = &block.terminator.kind {
142            if operand_uses_slot(func, slot) {
143                use_count += 1;
144            }
145            for arg in args {
146                if operand_uses_slot(arg, slot) {
147                    use_count += 1;
148                }
149            }
150            if use_count > 1 {
151                return true;
152            }
153        }
154    }
155    false
156}
157
158/// Check if a slot is mutated in the function (assigned to after initial definition).
159fn slot_is_mutated(slot: SlotId, mir: &MirFunction) -> bool {
160    let mut assign_count = 0u32;
161    for block in mir.iter_blocks() {
162        for stmt in &block.statements {
163            if let StatementKind::Assign(Place::Local(s), _) = &stmt.kind {
164                if *s == slot {
165                    assign_count += 1;
166                    if assign_count > 1 {
167                        return true;
168                    }
169                }
170            }
171        }
172    }
173    false
174}
175
176/// Check whether an rvalue uses (reads from) a given slot.
177fn rvalue_uses_slot(rvalue: &Rvalue, slot: SlotId) -> bool {
178    match rvalue {
179        Rvalue::Use(op) | Rvalue::Clone(op) | Rvalue::UnaryOp(_, op) => {
180            operand_uses_slot(op, slot)
181        }
182        Rvalue::Borrow(_, place) => place.root_local() == slot,
183        Rvalue::BinaryOp(_, lhs, rhs) => {
184            operand_uses_slot(lhs, slot) || operand_uses_slot(rhs, slot)
185        }
186        Rvalue::Aggregate(ops) => ops.iter().any(|op| operand_uses_slot(op, slot)),
187    }
188}
189
190/// Check whether an operand references a given slot.
191fn operand_uses_slot(op: &Operand, slot: SlotId) -> bool {
192    match op {
193        Operand::Copy(place) | Operand::Move(place) | Operand::MoveExplicit(place) => {
194            place.root_local() == slot
195        }
196        Operand::Constant(_) => false,
197    }
198}
199
200/// Run the storage planning pass on a single function.
201///
202/// The algorithm examines each local slot and decides its storage class:
203///
204/// 1. If `had_fallbacks` is true, all slots remain `Deferred` (analysis incomplete).
205/// 2. For each slot, check closure captures, mutations, aliasing, and loans.
206/// 3. Assign the appropriate `BindingStorageClass`.
207pub fn plan_storage(input: &StoragePlannerInput<'_>) -> StoragePlan {
208    let mut slot_classes = HashMap::new();
209    let mut slot_semantics = HashMap::new();
210
211    // If MIR lowering had fallbacks, we cannot trust the analysis.
212    // Leave everything Deferred so codegen uses conservative paths.
213    if input.had_fallbacks {
214        for slot_idx in 0..input.mir.num_locals {
215            let slot = SlotId(slot_idx);
216            slot_classes.insert(slot, BindingStorageClass::Deferred);
217            slot_semantics.insert(
218                slot,
219                BindingSemantics {
220                    ownership_class: BindingOwnershipClass::OwnedImmutable,
221                    storage_class: BindingStorageClass::Deferred,
222                    aliasability: Aliasability::Unique,
223                    mutation_capability: MutationCapability::Immutable,
224                    escape_status: EscapeStatus::Local,
225                },
226            );
227        }
228        return StoragePlan {
229            slot_classes,
230            slot_semantics,
231        };
232    }
233
234    for slot_idx in 0..input.mir.num_locals {
235        let slot = SlotId(slot_idx);
236        let (storage_class, semantics) = decide_slot_storage(slot, input);
237        slot_classes.insert(slot, storage_class);
238        slot_semantics.insert(slot, semantics);
239    }
240
241    StoragePlan {
242        slot_classes,
243        slot_semantics,
244    }
245}
246
247/// Decide the storage class for a single slot, returning both the storage class
248/// and enriched binding semantics.
249/// Decide the storage class and enriched semantics for a single slot.
250///
251/// ## Decision matrix
252///
253/// Priority order (first matching rule wins):
254///
255/// | # | Condition                                      | Storage class  |
256/// |---|------------------------------------------------|----------------|
257/// | 0 | Explicit `Reference` already set               | `Reference`    |
258/// | 1 | Slot holds a first-class reference              | `Reference`    |
259/// | 2 | Captured by closure with mutation               | `UniqueHeap`   |
260/// | 3 | `var` (Flexible) + aliased + mutated            | `SharedCow`    |
261/// | 3b| Escaped + aliased + mutated (any ownership)     | `SharedCow`    |
262/// | 4 | Everything else                                 | `Direct`       |
263///
264/// Notes:
265/// - "Aliased" means either captured by a closure or referenced from multiple
266///   MIR places (e.g. through a borrow chain).
267/// - `UniqueHeap` and `SharedCow` both result in heap boxing at runtime, but
268///   `SharedCow` adds copy-on-write semantics for safe shared mutation.
269/// - Immutable closure captures stay `Direct` — the closure gets a plain copy.
270fn decide_slot_storage(
271    slot: SlotId,
272    input: &StoragePlannerInput<'_>,
273) -> (BindingStorageClass, BindingSemantics) {
274    let is_captured = input.closure_captures.contains(&slot);
275    let is_mutably_captured = input.mutable_captures.contains(&slot);
276    let _has_loans = slot_has_active_loans(slot, input.analysis);
277    let is_mutated = slot_is_mutated(slot, input.mir);
278    let is_aliased = slot_is_aliased(slot, input.mir, input.closure_captures);
279
280    // Look up ownership class from binding semantics
281    let ownership = input
282        .binding_semantics
283        .get(&slot.0)
284        .map(|s| s.ownership_class);
285
286    // Check if the binding already has an explicit storage class set
287    let explicit_storage = input
288        .binding_semantics
289        .get(&slot.0)
290        .map(|s| s.storage_class);
291
292    let is_escaped = detect_escape_status(slot, input.mir, input.closure_captures)
293        == EscapeStatus::Escaped;
294
295    let storage_class = if let Some(BindingStorageClass::Reference) = explicit_storage {
296        // Already marked as a reference binding — preserve it.
297        BindingStorageClass::Reference
298    } else if slot_holds_reference(slot, input.mir) {
299        // Rule 1: Bindings that hold first-class references.
300        BindingStorageClass::Reference
301    } else if is_mutably_captured {
302        // Rule 2: Captured by closure with mutation → UniqueHeap.
303        BindingStorageClass::UniqueHeap
304    } else if matches!(ownership, Some(BindingOwnershipClass::Flexible))
305        && is_aliased
306        && is_mutated
307    {
308        // Rule 3: `var` bindings that are aliased AND mutated → SharedCow.
309        BindingStorageClass::SharedCow
310    } else if is_escaped && is_aliased && is_mutated {
311        // Rule 3b: Escaped mutable aliased bindings → SharedCow.
312        // Even non-Flexible bindings need COW when they escape with aliasing.
313        BindingStorageClass::SharedCow
314    } else {
315        // Rule 4: Captured by closure (immutably) — still Direct.
316        // Default: Direct storage (stack slot).
317        BindingStorageClass::Direct
318    };
319
320    // Compute enriched metadata
321    let aliasability = if is_captured || is_aliased {
322        if is_mutated {
323            Aliasability::SharedMutable
324        } else {
325            Aliasability::SharedImmutable
326        }
327    } else {
328        Aliasability::Unique
329    };
330
331    let mutation_capability = match (ownership, is_mutated) {
332        (Some(BindingOwnershipClass::OwnedImmutable), _) => MutationCapability::Immutable,
333        (Some(BindingOwnershipClass::OwnedMutable), _) => MutationCapability::LocalMutable,
334        (Some(BindingOwnershipClass::Flexible), true) => MutationCapability::SharedMutable,
335        (Some(BindingOwnershipClass::Flexible), false) => MutationCapability::Immutable,
336        (None, true) => MutationCapability::LocalMutable,
337        (None, false) => MutationCapability::Immutable,
338    };
339
340    let escape_status = detect_escape_status(slot, input.mir, input.closure_captures);
341
342    let enriched = BindingSemantics {
343        ownership_class: ownership.unwrap_or(BindingOwnershipClass::OwnedImmutable),
344        storage_class: storage_class,
345        aliasability,
346        mutation_capability,
347        escape_status,
348    };
349
350    (storage_class, enriched)
351}
352
353/// Detect the escape status of a slot by examining MIR dataflow.
354///
355/// - `Escaped`: The slot's value flows, directly or through local aliases, into
356///   the return slot (`SlotId(0)`).
357/// - `Captured`: The slot is captured by a closure.
358/// - `Local`: The slot stays within the declaring scope.
359pub fn detect_escape_status(
360    slot: SlotId,
361    mir: &MirFunction,
362    closure_captures: &HashSet<SlotId>,
363) -> EscapeStatus {
364    if slot != SlotId(0) {
365        let mut visited = HashSet::new();
366        if slot_flows_to_return(slot, mir, &mut visited) {
367            return EscapeStatus::Escaped;
368        }
369    }
370
371    if closure_captures.contains(&slot) {
372        EscapeStatus::Captured
373    } else {
374        EscapeStatus::Local
375    }
376}
377
378fn slot_flows_to_return(
379    slot: SlotId,
380    mir: &MirFunction,
381    visited: &mut HashSet<SlotId>,
382) -> bool {
383    if !visited.insert(slot) {
384        return false;
385    }
386
387    let return_slot = SlotId(0);
388    for block in mir.iter_blocks() {
389        for stmt in &block.statements {
390            let StatementKind::Assign(Place::Local(dest), rvalue) = &stmt.kind else {
391                continue;
392            };
393            if !rvalue_uses_slot(rvalue, slot) {
394                continue;
395            }
396            if *dest == return_slot {
397                return true;
398            }
399            if *dest != slot && slot_flows_to_return(*dest, mir, visited) {
400                return true;
401            }
402        }
403    }
404
405    false
406}
407
408/// Check if a slot was assigned a `Borrow` rvalue anywhere in the function.
409fn slot_holds_reference(slot: SlotId, mir: &MirFunction) -> bool {
410    for block in mir.iter_blocks() {
411        for stmt in &block.statements {
412            if let StatementKind::Assign(Place::Local(s), Rvalue::Borrow(_, _)) = &stmt.kind {
413                if *s == slot {
414                    return true;
415                }
416            }
417        }
418    }
419    false
420}
421
422#[cfg(test)]
423mod tests {
424    use super::*;
425    use crate::mir::analysis::BorrowAnalysis;
426    use crate::mir::liveness::LivenessResult;
427    use crate::mir::types::*;
428    use crate::type_tracking::{
429        Aliasability, BindingOwnershipClass, BindingSemantics, BindingStorageClass, EscapeStatus,
430        MutationCapability,
431    };
432
433    fn span() -> shape_ast::ast::Span {
434        shape_ast::ast::Span { start: 0, end: 1 }
435    }
436
437    fn make_stmt(kind: StatementKind, point: u32) -> MirStatement {
438        MirStatement {
439            kind,
440            span: span(),
441            point: Point(point),
442        }
443    }
444
445    fn make_terminator(kind: TerminatorKind) -> Terminator {
446        Terminator { kind, span: span() }
447    }
448
449    fn empty_analysis() -> BorrowAnalysis {
450        BorrowAnalysis::empty()
451    }
452
453    /// Helper: create a simple MIR function with the given blocks.
454    fn make_mir(name: &str, blocks: Vec<BasicBlock>, num_locals: u16) -> MirFunction {
455        MirFunction {
456            name: name.to_string(),
457            blocks,
458            num_locals,
459            param_slots: vec![],
460            param_reference_kinds: vec![],
461            local_types: (0..num_locals).map(|_| LocalTypeInfo::Unknown).collect(),
462            span: span(),
463        }
464    }
465
466    // ── Test: Direct storage for simple binding ──────────────────────────
467
468    #[test]
469    fn test_simple_binding_gets_direct() {
470        // bb0: _0 = 42; return
471        let mir = make_mir(
472            "test_direct",
473            vec![BasicBlock {
474                id: BasicBlockId(0),
475                statements: vec![make_stmt(
476                    StatementKind::Assign(
477                        Place::Local(SlotId(0)),
478                        Rvalue::Use(Operand::Constant(MirConstant::Int(42))),
479                    ),
480                    0,
481                )],
482                terminator: make_terminator(TerminatorKind::Return),
483            }],
484            1,
485        );
486
487        let analysis = empty_analysis();
488        let binding_semantics = HashMap::new();
489        let closure_captures = HashSet::new();
490        let mutable_captures = HashSet::new();
491
492        let input = StoragePlannerInput {
493            mir: &mir,
494            analysis: &analysis,
495            binding_semantics: &binding_semantics,
496            closure_captures: &closure_captures,
497            mutable_captures: &mutable_captures,
498            had_fallbacks: false,
499        };
500
501        let plan = plan_storage(&input);
502        assert_eq!(
503            plan.slot_classes.get(&SlotId(0)),
504            Some(&BindingStorageClass::Direct)
505        );
506    }
507
508    // ── Test: Deferred when had_fallbacks ─────────────────────────────────
509
510    #[test]
511    fn test_fallback_gives_deferred() {
512        let mir = make_mir(
513            "test_deferred",
514            vec![BasicBlock {
515                id: BasicBlockId(0),
516                statements: vec![],
517                terminator: make_terminator(TerminatorKind::Return),
518            }],
519            2,
520        );
521
522        let analysis = empty_analysis();
523        let binding_semantics = HashMap::new();
524        let closure_captures = HashSet::new();
525        let mutable_captures = HashSet::new();
526
527        let input = StoragePlannerInput {
528            mir: &mir,
529            analysis: &analysis,
530            binding_semantics: &binding_semantics,
531            closure_captures: &closure_captures,
532            mutable_captures: &mutable_captures,
533            had_fallbacks: true,
534        };
535
536        let plan = plan_storage(&input);
537        assert_eq!(
538            plan.slot_classes.get(&SlotId(0)),
539            Some(&BindingStorageClass::Deferred)
540        );
541        assert_eq!(
542            plan.slot_classes.get(&SlotId(1)),
543            Some(&BindingStorageClass::Deferred)
544        );
545    }
546
547    // ── Test: UniqueHeap for mutably captured slot ────────────────────────
548
549    #[test]
550    fn test_mutable_capture_gets_unique_heap() {
551        // bb0: _0 = 0; ClosureCapture(copy _0); _0 = 1; return
552        let mir = make_mir(
553            "test_unique_heap",
554            vec![BasicBlock {
555                id: BasicBlockId(0),
556                statements: vec![
557                    make_stmt(
558                        StatementKind::Assign(
559                            Place::Local(SlotId(0)),
560                            Rvalue::Use(Operand::Constant(MirConstant::Int(0))),
561                        ),
562                        0,
563                    ),
564                    make_stmt(
565                        StatementKind::ClosureCapture {
566                            closure_slot: SlotId(0),
567                            operands: vec![Operand::Copy(Place::Local(SlotId(0)))],
568                        },
569                        1,
570                    ),
571                    make_stmt(
572                        StatementKind::Assign(
573                            Place::Local(SlotId(0)),
574                            Rvalue::Use(Operand::Constant(MirConstant::Int(1))),
575                        ),
576                        2,
577                    ),
578                ],
579                terminator: make_terminator(TerminatorKind::Return),
580            }],
581            1,
582        );
583
584        let analysis = empty_analysis();
585        let binding_semantics = HashMap::new();
586
587        // Simulate what collect_closure_captures would find
588        let mut closure_captures = HashSet::new();
589        closure_captures.insert(SlotId(0));
590        let mut mutable_captures = HashSet::new();
591        mutable_captures.insert(SlotId(0));
592
593        let input = StoragePlannerInput {
594            mir: &mir,
595            analysis: &analysis,
596            binding_semantics: &binding_semantics,
597            closure_captures: &closure_captures,
598            mutable_captures: &mutable_captures,
599            had_fallbacks: false,
600        };
601
602        let plan = plan_storage(&input);
603        assert_eq!(
604            plan.slot_classes.get(&SlotId(0)),
605            Some(&BindingStorageClass::UniqueHeap)
606        );
607    }
608
609    // ── Test: SharedCow for aliased+mutated var binding ──────────────────
610
611    #[test]
612    fn test_aliased_mutated_var_gets_shared_cow() {
613        // bb0: _0 = "hello"; _1 = copy _0; _2 = copy _0; _0 = "world"; return
614        let mir = make_mir(
615            "test_shared_cow",
616            vec![BasicBlock {
617                id: BasicBlockId(0),
618                statements: vec![
619                    make_stmt(
620                        StatementKind::Assign(
621                            Place::Local(SlotId(0)),
622                            Rvalue::Use(Operand::Constant(MirConstant::StringId(0))),
623                        ),
624                        0,
625                    ),
626                    make_stmt(
627                        StatementKind::Assign(
628                            Place::Local(SlotId(1)),
629                            Rvalue::Use(Operand::Copy(Place::Local(SlotId(0)))),
630                        ),
631                        1,
632                    ),
633                    make_stmt(
634                        StatementKind::Assign(
635                            Place::Local(SlotId(2)),
636                            Rvalue::Use(Operand::Copy(Place::Local(SlotId(0)))),
637                        ),
638                        2,
639                    ),
640                    make_stmt(
641                        StatementKind::Assign(
642                            Place::Local(SlotId(0)),
643                            Rvalue::Use(Operand::Constant(MirConstant::StringId(1))),
644                        ),
645                        3,
646                    ),
647                ],
648                terminator: make_terminator(TerminatorKind::Return),
649            }],
650            3,
651        );
652
653        let analysis = empty_analysis();
654        let mut binding_semantics = HashMap::new();
655        // Mark slot 0 as a `var` (Flexible) binding
656        binding_semantics.insert(
657            0u16,
658            BindingSemantics::deferred(BindingOwnershipClass::Flexible),
659        );
660
661        let closure_captures = HashSet::new();
662        let mutable_captures = HashSet::new();
663
664        let input = StoragePlannerInput {
665            mir: &mir,
666            analysis: &analysis,
667            binding_semantics: &binding_semantics,
668            closure_captures: &closure_captures,
669            mutable_captures: &mutable_captures,
670            had_fallbacks: false,
671        };
672
673        let plan = plan_storage(&input);
674        assert_eq!(
675            plan.slot_classes.get(&SlotId(0)),
676            Some(&BindingStorageClass::SharedCow),
677            "aliased + mutated + Flexible => SharedCow"
678        );
679    }
680
681    // ── Test: Reference for borrow-holding slot ──────────────────────────
682
683    #[test]
684    fn test_borrow_holder_gets_reference() {
685        // bb0: _0 = 42; _1 = &_0; return
686        let mir = make_mir(
687            "test_reference",
688            vec![BasicBlock {
689                id: BasicBlockId(0),
690                statements: vec![
691                    make_stmt(
692                        StatementKind::Assign(
693                            Place::Local(SlotId(0)),
694                            Rvalue::Use(Operand::Constant(MirConstant::Int(42))),
695                        ),
696                        0,
697                    ),
698                    make_stmt(
699                        StatementKind::Assign(
700                            Place::Local(SlotId(1)),
701                            Rvalue::Borrow(BorrowKind::Shared, Place::Local(SlotId(0))),
702                        ),
703                        1,
704                    ),
705                ],
706                terminator: make_terminator(TerminatorKind::Return),
707            }],
708            2,
709        );
710
711        // Create analysis with a loan on slot 0
712        let mut analysis = empty_analysis();
713        analysis.loans.insert(
714            LoanId(0),
715            crate::mir::analysis::LoanInfo {
716                id: LoanId(0),
717                borrowed_place: Place::Local(SlotId(0)),
718                kind: BorrowKind::Shared,
719                issued_at: Point(1),
720                span: span(),
721                region_depth: 1,
722            },
723        );
724
725        let binding_semantics = HashMap::new();
726        let closure_captures = HashSet::new();
727        let mutable_captures = HashSet::new();
728
729        let input = StoragePlannerInput {
730            mir: &mir,
731            analysis: &analysis,
732            binding_semantics: &binding_semantics,
733            closure_captures: &closure_captures,
734            mutable_captures: &mutable_captures,
735            had_fallbacks: false,
736        };
737
738        let plan = plan_storage(&input);
739        // _1 holds a borrow rvalue → Reference
740        assert_eq!(
741            plan.slot_classes.get(&SlotId(1)),
742            Some(&BindingStorageClass::Reference),
743            "_1 holds &_0 borrow → Reference"
744        );
745    }
746
747    // ── Test: Explicit Reference preserved ───────────────────────────────
748
749    #[test]
750    fn test_explicit_reference_preserved() {
751        let mir = make_mir(
752            "test_explicit_ref",
753            vec![BasicBlock {
754                id: BasicBlockId(0),
755                statements: vec![],
756                terminator: make_terminator(TerminatorKind::Return),
757            }],
758            1,
759        );
760
761        let analysis = empty_analysis();
762        let mut binding_semantics = HashMap::new();
763        binding_semantics.insert(
764            0u16,
765            BindingSemantics {
766                ownership_class: BindingOwnershipClass::OwnedImmutable,
767                storage_class: BindingStorageClass::Reference,
768                aliasability: Aliasability::Unique,
769                mutation_capability: MutationCapability::Immutable,
770                escape_status: EscapeStatus::Local,
771            },
772        );
773
774        let closure_captures = HashSet::new();
775        let mutable_captures = HashSet::new();
776
777        let input = StoragePlannerInput {
778            mir: &mir,
779            analysis: &analysis,
780            binding_semantics: &binding_semantics,
781            closure_captures: &closure_captures,
782            mutable_captures: &mutable_captures,
783            had_fallbacks: false,
784        };
785
786        let plan = plan_storage(&input);
787        assert_eq!(
788            plan.slot_classes.get(&SlotId(0)),
789            Some(&BindingStorageClass::Reference),
790            "explicit Reference annotation preserved"
791        );
792    }
793
794    // ── Test: collect_closure_captures ────────────────────────────────────
795
796    #[test]
797    fn test_collect_closure_captures() {
798        // bb0: _0 = 1; _1 = 2; ClosureCapture(copy _0, copy _1); _0 = 3; return
799        let mir = make_mir(
800            "test_collect",
801            vec![BasicBlock {
802                id: BasicBlockId(0),
803                statements: vec![
804                    make_stmt(
805                        StatementKind::Assign(
806                            Place::Local(SlotId(0)),
807                            Rvalue::Use(Operand::Constant(MirConstant::Int(1))),
808                        ),
809                        0,
810                    ),
811                    make_stmt(
812                        StatementKind::Assign(
813                            Place::Local(SlotId(1)),
814                            Rvalue::Use(Operand::Constant(MirConstant::Int(2))),
815                        ),
816                        1,
817                    ),
818                    make_stmt(
819                        StatementKind::ClosureCapture {
820                            closure_slot: SlotId(2),
821                            operands: vec![
822                                Operand::Copy(Place::Local(SlotId(0))),
823                                Operand::Copy(Place::Local(SlotId(1))),
824                            ],
825                        },
826                        2,
827                    ),
828                    make_stmt(
829                        StatementKind::Assign(
830                            Place::Local(SlotId(0)),
831                            Rvalue::Use(Operand::Constant(MirConstant::Int(3))),
832                        ),
833                        3,
834                    ),
835                ],
836                terminator: make_terminator(TerminatorKind::Return),
837            }],
838            2,
839        );
840
841        let (captures, mutable) = collect_closure_captures(&mir);
842        assert!(captures.contains(&SlotId(0)));
843        assert!(captures.contains(&SlotId(1)));
844        // _0 is assigned twice (before and after capture) → mutably captured
845        assert!(mutable.contains(&SlotId(0)));
846        // _1 is assigned only once (initial definition) → not mutably captured
847        // Note: our conservative check counts any assignment, but _1 only has one
848        assert!(!mutable.contains(&SlotId(1)));
849    }
850
851    // ── Test: Immutable captured slot stays Direct ───────────────────────
852
853    #[test]
854    fn test_immutable_capture_stays_direct() {
855        // bb0: _0 = 1; ClosureCapture(copy _0); return
856        let mir = make_mir(
857            "test_immutable_capture",
858            vec![BasicBlock {
859                id: BasicBlockId(0),
860                statements: vec![
861                    make_stmt(
862                        StatementKind::Assign(
863                            Place::Local(SlotId(0)),
864                            Rvalue::Use(Operand::Constant(MirConstant::Int(1))),
865                        ),
866                        0,
867                    ),
868                    make_stmt(
869                        StatementKind::ClosureCapture {
870                            closure_slot: SlotId(0),
871                            operands: vec![Operand::Copy(Place::Local(SlotId(0)))],
872                        },
873                        1,
874                    ),
875                ],
876                terminator: make_terminator(TerminatorKind::Return),
877            }],
878            1,
879        );
880
881        let analysis = empty_analysis();
882        let binding_semantics = HashMap::new();
883        let mut closure_captures = HashSet::new();
884        closure_captures.insert(SlotId(0));
885        let mutable_captures = HashSet::new();
886
887        let input = StoragePlannerInput {
888            mir: &mir,
889            analysis: &analysis,
890            binding_semantics: &binding_semantics,
891            closure_captures: &closure_captures,
892            mutable_captures: &mutable_captures,
893            had_fallbacks: false,
894        };
895
896        let plan = plan_storage(&input);
897        assert_eq!(
898            plan.slot_classes.get(&SlotId(0)),
899            Some(&BindingStorageClass::Direct),
900            "immutable capture stays Direct"
901        );
902    }
903
904    // ── Test: Non-Flexible ownership doesn't get SharedCow ───────────────
905
906    #[test]
907    fn test_owned_mutable_aliased_mutated_stays_direct() {
908        // A `let mut` binding that is aliased and mutated does NOT get
909        // SharedCow — only `var` (Flexible) does.
910        let mir = make_mir(
911            "test_let_mut_no_cow",
912            vec![BasicBlock {
913                id: BasicBlockId(0),
914                statements: vec![
915                    make_stmt(
916                        StatementKind::Assign(
917                            Place::Local(SlotId(0)),
918                            Rvalue::Use(Operand::Constant(MirConstant::Int(0))),
919                        ),
920                        0,
921                    ),
922                    make_stmt(
923                        StatementKind::Assign(
924                            Place::Local(SlotId(1)),
925                            Rvalue::Use(Operand::Copy(Place::Local(SlotId(0)))),
926                        ),
927                        1,
928                    ),
929                    make_stmt(
930                        StatementKind::Assign(
931                            Place::Local(SlotId(2)),
932                            Rvalue::Use(Operand::Copy(Place::Local(SlotId(0)))),
933                        ),
934                        2,
935                    ),
936                    make_stmt(
937                        StatementKind::Assign(
938                            Place::Local(SlotId(0)),
939                            Rvalue::Use(Operand::Constant(MirConstant::Int(99))),
940                        ),
941                        3,
942                    ),
943                ],
944                terminator: make_terminator(TerminatorKind::Return),
945            }],
946            3,
947        );
948
949        let analysis = empty_analysis();
950        let mut binding_semantics = HashMap::new();
951        binding_semantics.insert(
952            0u16,
953            BindingSemantics::deferred(BindingOwnershipClass::OwnedMutable),
954        );
955
956        let closure_captures = HashSet::new();
957        let mutable_captures = HashSet::new();
958
959        let input = StoragePlannerInput {
960            mir: &mir,
961            analysis: &analysis,
962            binding_semantics: &binding_semantics,
963            closure_captures: &closure_captures,
964            mutable_captures: &mutable_captures,
965            had_fallbacks: false,
966        };
967
968        let plan = plan_storage(&input);
969        assert_eq!(
970            plan.slot_classes.get(&SlotId(0)),
971            Some(&BindingStorageClass::Direct),
972            "OwnedMutable (let mut) stays Direct even when aliased+mutated"
973        );
974    }
975
976    // ── Test: All slots planned ──────────────────────────────────────────
977
978    #[test]
979    fn test_all_slots_planned() {
980        let mir = make_mir(
981            "test_all_planned",
982            vec![BasicBlock {
983                id: BasicBlockId(0),
984                statements: vec![],
985                terminator: make_terminator(TerminatorKind::Return),
986            }],
987            5,
988        );
989
990        let analysis = empty_analysis();
991        let binding_semantics = HashMap::new();
992        let closure_captures = HashSet::new();
993        let mutable_captures = HashSet::new();
994
995        let input = StoragePlannerInput {
996            mir: &mir,
997            analysis: &analysis,
998            binding_semantics: &binding_semantics,
999            closure_captures: &closure_captures,
1000            mutable_captures: &mutable_captures,
1001            had_fallbacks: false,
1002        };
1003
1004        let plan = plan_storage(&input);
1005        assert_eq!(plan.slot_classes.len(), 5, "all slots must be planned");
1006        for i in 0..5 {
1007            assert!(
1008                plan.slot_classes.contains_key(&SlotId(i)),
1009                "slot {} must be in plan",
1010                i
1011            );
1012        }
1013    }
1014
1015    // ── Test: UniqueHeap takes priority over SharedCow ───────────────────
1016
1017    #[test]
1018    fn test_mutable_capture_beats_shared_cow() {
1019        // A `var` binding that is both mutably captured AND aliased+mutated
1020        // should get UniqueHeap (closure mutation takes priority over COW).
1021        let mir = make_mir(
1022            "test_priority",
1023            vec![BasicBlock {
1024                id: BasicBlockId(0),
1025                statements: vec![
1026                    make_stmt(
1027                        StatementKind::Assign(
1028                            Place::Local(SlotId(0)),
1029                            Rvalue::Use(Operand::Constant(MirConstant::Int(0))),
1030                        ),
1031                        0,
1032                    ),
1033                    make_stmt(
1034                        StatementKind::ClosureCapture {
1035                            closure_slot: SlotId(0),
1036                            operands: vec![Operand::Copy(Place::Local(SlotId(0)))],
1037                        },
1038                        1,
1039                    ),
1040                    make_stmt(
1041                        StatementKind::Assign(
1042                            Place::Local(SlotId(0)),
1043                            Rvalue::Use(Operand::Constant(MirConstant::Int(1))),
1044                        ),
1045                        2,
1046                    ),
1047                ],
1048                terminator: make_terminator(TerminatorKind::Return),
1049            }],
1050            1,
1051        );
1052
1053        let analysis = empty_analysis();
1054        let mut binding_semantics = HashMap::new();
1055        binding_semantics.insert(
1056            0u16,
1057            BindingSemantics::deferred(BindingOwnershipClass::Flexible),
1058        );
1059
1060        let mut closure_captures = HashSet::new();
1061        closure_captures.insert(SlotId(0));
1062        let mut mutable_captures = HashSet::new();
1063        mutable_captures.insert(SlotId(0));
1064
1065        let input = StoragePlannerInput {
1066            mir: &mir,
1067            analysis: &analysis,
1068            binding_semantics: &binding_semantics,
1069            closure_captures: &closure_captures,
1070            mutable_captures: &mutable_captures,
1071            had_fallbacks: false,
1072        };
1073
1074        let plan = plan_storage(&input);
1075        assert_eq!(
1076            plan.slot_classes.get(&SlotId(0)),
1077            Some(&BindingStorageClass::UniqueHeap),
1078            "mutable capture → UniqueHeap overrides SharedCow"
1079        );
1080    }
1081
1082    // ── Test: detect_escape_status ───────────────────────────────────────
1083
1084    #[test]
1085    fn test_escape_status_local() {
1086        // bb0: _1 = 42; return
1087        // _1 never flows to _0 (return slot) → Local
1088        let mir = make_mir(
1089            "test_local_escape",
1090            vec![BasicBlock {
1091                id: BasicBlockId(0),
1092                statements: vec![make_stmt(
1093                    StatementKind::Assign(
1094                        Place::Local(SlotId(1)),
1095                        Rvalue::Use(Operand::Constant(MirConstant::Int(42))),
1096                    ),
1097                    0,
1098                )],
1099                terminator: make_terminator(TerminatorKind::Return),
1100            }],
1101            2,
1102        );
1103
1104        let captures = HashSet::new();
1105        assert_eq!(
1106            detect_escape_status(SlotId(1), &mir, &captures),
1107            EscapeStatus::Local,
1108            "slot that doesn't escape should be Local"
1109        );
1110    }
1111
1112    #[test]
1113    fn test_escape_status_escaped_via_return() {
1114        // bb0: _1 = 42; _0 = copy _1; return
1115        // _1 flows to return slot _0 → Escaped
1116        let mir = make_mir(
1117            "test_escaped",
1118            vec![BasicBlock {
1119                id: BasicBlockId(0),
1120                statements: vec![
1121                    make_stmt(
1122                        StatementKind::Assign(
1123                            Place::Local(SlotId(1)),
1124                            Rvalue::Use(Operand::Constant(MirConstant::Int(42))),
1125                        ),
1126                        0,
1127                    ),
1128                    make_stmt(
1129                        StatementKind::Assign(
1130                            Place::Local(SlotId(0)),
1131                            Rvalue::Use(Operand::Copy(Place::Local(SlotId(1)))),
1132                        ),
1133                        1,
1134                    ),
1135                ],
1136                terminator: make_terminator(TerminatorKind::Return),
1137            }],
1138            2,
1139        );
1140
1141        let captures = HashSet::new();
1142        assert_eq!(
1143            detect_escape_status(SlotId(1), &mir, &captures),
1144            EscapeStatus::Escaped,
1145            "slot assigned to return slot should be Escaped"
1146        );
1147    }
1148
1149    #[test]
1150    fn test_escape_status_escaped_via_local_alias_chain() {
1151        // bb0: _2 = 42; _1 = copy _2; _0 = copy _1; return
1152        // _2 reaches the return slot transitively through _1.
1153        let mir = make_mir(
1154            "test_transitive_escape",
1155            vec![BasicBlock {
1156                id: BasicBlockId(0),
1157                statements: vec![
1158                    make_stmt(
1159                        StatementKind::Assign(
1160                            Place::Local(SlotId(2)),
1161                            Rvalue::Use(Operand::Constant(MirConstant::Int(42))),
1162                        ),
1163                        0,
1164                    ),
1165                    make_stmt(
1166                        StatementKind::Assign(
1167                            Place::Local(SlotId(1)),
1168                            Rvalue::Use(Operand::Copy(Place::Local(SlotId(2)))),
1169                        ),
1170                        1,
1171                    ),
1172                    make_stmt(
1173                        StatementKind::Assign(
1174                            Place::Local(SlotId(0)),
1175                            Rvalue::Use(Operand::Copy(Place::Local(SlotId(1)))),
1176                        ),
1177                        2,
1178                    ),
1179                ],
1180                terminator: make_terminator(TerminatorKind::Return),
1181            }],
1182            3,
1183        );
1184
1185        let captures = HashSet::new();
1186        assert_eq!(
1187            detect_escape_status(SlotId(2), &mir, &captures),
1188            EscapeStatus::Escaped,
1189            "slot flowing into a returned local alias should be Escaped"
1190        );
1191    }
1192
1193    #[test]
1194    fn test_escape_status_captured() {
1195        // bb0: _1 = 42; ClosureCapture(copy _1); return
1196        let mir = make_mir(
1197            "test_captured",
1198            vec![BasicBlock {
1199                id: BasicBlockId(0),
1200                statements: vec![
1201                    make_stmt(
1202                        StatementKind::Assign(
1203                            Place::Local(SlotId(1)),
1204                            Rvalue::Use(Operand::Constant(MirConstant::Int(42))),
1205                        ),
1206                        0,
1207                    ),
1208                    make_stmt(
1209                        StatementKind::ClosureCapture {
1210                            closure_slot: SlotId(1),
1211                            operands: vec![Operand::Copy(Place::Local(SlotId(1)))],
1212                        },
1213                        1,
1214                    ),
1215                ],
1216                terminator: make_terminator(TerminatorKind::Return),
1217            }],
1218            2,
1219        );
1220
1221        let mut captures = HashSet::new();
1222        captures.insert(SlotId(1));
1223        assert_eq!(
1224            detect_escape_status(SlotId(1), &mir, &captures),
1225            EscapeStatus::Captured,
1226            "slot captured by closure should be Captured"
1227        );
1228    }
1229
1230    #[test]
1231    fn test_escape_status_escaped_beats_captured() {
1232        // A slot that both escapes to return AND is captured → Escaped takes priority
1233        // bb0: _1 = 42; ClosureCapture(copy _1); _0 = copy _1; return
1234        let mir = make_mir(
1235            "test_escaped_captured",
1236            vec![BasicBlock {
1237                id: BasicBlockId(0),
1238                statements: vec![
1239                    make_stmt(
1240                        StatementKind::Assign(
1241                            Place::Local(SlotId(1)),
1242                            Rvalue::Use(Operand::Constant(MirConstant::Int(42))),
1243                        ),
1244                        0,
1245                    ),
1246                    make_stmt(
1247                        StatementKind::ClosureCapture {
1248                            closure_slot: SlotId(1),
1249                            operands: vec![Operand::Copy(Place::Local(SlotId(1)))],
1250                        },
1251                        1,
1252                    ),
1253                    make_stmt(
1254                        StatementKind::Assign(
1255                            Place::Local(SlotId(0)),
1256                            Rvalue::Use(Operand::Copy(Place::Local(SlotId(1)))),
1257                        ),
1258                        2,
1259                    ),
1260                ],
1261                terminator: make_terminator(TerminatorKind::Return),
1262            }],
1263            2,
1264        );
1265
1266        let mut captures = HashSet::new();
1267        captures.insert(SlotId(1));
1268        assert_eq!(
1269            detect_escape_status(SlotId(1), &mir, &captures),
1270            EscapeStatus::Escaped,
1271            "Escaped takes priority over Captured"
1272        );
1273    }
1274
1275    #[test]
1276    fn test_escape_semantics_in_plan() {
1277        // Verify that the storage plan captures Escaped status on semantics
1278        // bb0: _1 = 42; _0 = copy _1; return
1279        let mir = make_mir(
1280            "test_escape_in_plan",
1281            vec![BasicBlock {
1282                id: BasicBlockId(0),
1283                statements: vec![
1284                    make_stmt(
1285                        StatementKind::Assign(
1286                            Place::Local(SlotId(1)),
1287                            Rvalue::Use(Operand::Constant(MirConstant::Int(42))),
1288                        ),
1289                        0,
1290                    ),
1291                    make_stmt(
1292                        StatementKind::Assign(
1293                            Place::Local(SlotId(0)),
1294                            Rvalue::Use(Operand::Copy(Place::Local(SlotId(1)))),
1295                        ),
1296                        1,
1297                    ),
1298                ],
1299                terminator: make_terminator(TerminatorKind::Return),
1300            }],
1301            2,
1302        );
1303
1304        let analysis = empty_analysis();
1305        let binding_semantics = HashMap::new();
1306        let closure_captures = HashSet::new();
1307        let mutable_captures = HashSet::new();
1308
1309        let input = StoragePlannerInput {
1310            mir: &mir,
1311            analysis: &analysis,
1312            binding_semantics: &binding_semantics,
1313            closure_captures: &closure_captures,
1314            mutable_captures: &mutable_captures,
1315            had_fallbacks: false,
1316        };
1317
1318        let plan = plan_storage(&input);
1319        assert_eq!(
1320            plan.slot_semantics.get(&SlotId(1)).map(|s| s.escape_status),
1321            Some(EscapeStatus::Escaped),
1322            "slot flowing to return should have Escaped status in plan"
1323        );
1324    }
1325}