Skip to main content

shape_vm/mir/
types.rs

//! Core MIR types: [`Place`], [`MirStatement`], [`Terminator`], and [`BasicBlock`].
//!
//! These form the mid-level IR that the borrow solver operates on.
//! Places describe what can be borrowed (locals, fields, indices, dereferences).
//! Statements and terminators make up the basic blocks of a control-flow graph.

7use shape_ast::ast::Span;
8use std::fmt;
9
// ── Identifiers ──────────────────────────────────────────────────────

/// Index of a local variable slot.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct SlotId(pub u16);

/// Index of a field within a struct/object value.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct FieldIdx(pub u16);

/// Index of a basic block inside a MIR function.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct BasicBlockId(pub u32);

/// A program point: the index of a statement in the function's linearized MIR.
///
/// Serves as the "point" dimension of the Datafrog relations.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct Point(pub u32);

/// Unique identifier of a loan (a single borrow).
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct LoanId(pub u32);

impl fmt::Display for SlotId {
    /// Formats as `_N`, e.g. `_3`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let SlotId(index) = self;
        write!(f, "_{}", index)
    }
}

impl fmt::Display for BasicBlockId {
    /// Formats as `bbN`, e.g. `bb0`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let BasicBlockId(index) = self;
        write!(f, "bb{}", index)
    }
}

impl fmt::Display for Point {
    /// Formats as `pN`, e.g. `p12`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Point(index) = self;
        write!(f, "p{}", index)
    }
}

impl fmt::Display for LoanId {
    /// Formats as `LN`, e.g. `L4`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let LoanId(index) = self;
        write!(f, "L{}", index)
    }
}

57// ── Places ───────────────────────────────────────────────────────────
58
59/// A place is something that can be borrowed or assigned to.
60/// Tracks granular access paths for disjoint borrow analysis.
61#[derive(Debug, Clone, PartialEq, Eq, Hash)]
62pub enum Place {
63    /// A local variable: `x`
64    Local(SlotId),
65    /// A field of a place: `x.field_name`
66    Field(Box<Place>, FieldIdx),
67    /// An index into a place: `x[i]` — index analysis is conservative in v1.
68    /// The index operand is boxed to break the recursive type cycle (Place → Operand → Place).
69    Index(Box<Place>, Box<Operand>),
70    /// Dereferencing a reference: `*r`
71    Deref(Box<Place>),
72}
73
74impl Place {
75    /// Get the root local of this place (e.g., `x.a.b` → `x`).
76    pub fn root_local(&self) -> SlotId {
77        match self {
78            Place::Local(slot) => *slot,
79            Place::Field(base, _) | Place::Index(base, _) | Place::Deref(base) => base.root_local(),
80        }
81    }
82
83    /// Check if this place is a prefix of another (for conflict detection).
84    /// `x` is a prefix of `x.a`, `x` is a prefix of `x[i]`, etc.
85    pub fn is_prefix_of(&self, other: &Place) -> bool {
86        if self == other {
87            return true;
88        }
89        match other {
90            Place::Local(_) => false,
91            Place::Field(base, _) | Place::Index(base, _) | Place::Deref(base) => {
92                self.is_prefix_of(base)
93            }
94        }
95    }
96
97    /// Check whether two places conflict (one borrows/writes something the other uses).
98    /// Two places conflict if one is a prefix of the other, or they're the same.
99    /// In v1, disjoint field borrows are tracked (x.a and x.b don't conflict),
100    /// but index borrows are conservative (x[i] and x[j] always conflict).
101    pub fn conflicts_with(&self, other: &Place) -> bool {
102        // Same root?
103        if self.root_local() != other.root_local() {
104            return false;
105        }
106        // Walk both paths to check overlap
107        self.is_prefix_of(other) || other.is_prefix_of(self) || self.overlaps(other)
108    }
109
110    fn overlaps(&self, other: &Place) -> bool {
111        match (self, other) {
112            (Place::Local(a), Place::Local(b)) => a == b,
113            // Disjoint fields: x.a and x.b do NOT conflict
114            (Place::Field(base_a, field_a), Place::Field(base_b, field_b)) => {
115                if base_a == base_b {
116                    field_a == field_b
117                } else {
118                    base_a.overlaps(base_b)
119                }
120            }
121            // Conservative: x[i] and x[j] always conflict
122            (Place::Index(base_a, _), Place::Index(base_b, _)) => base_a.overlaps(base_b),
123            _ => self.is_prefix_of(other) || other.is_prefix_of(self),
124        }
125    }
126}
127
128impl fmt::Display for Place {
129    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
130        match self {
131            Place::Local(slot) => write!(f, "{}", slot),
132            Place::Field(base, field) => write!(f, "{}.{}", base, field.0),
133            Place::Index(base, idx) => write!(f, "{}[{}]", base, idx),
134            Place::Deref(base) => write!(f, "*{}", base),
135        }
136    }
137}
138
139// ── Operands ─────────────────────────────────────────────────────────
140
141/// An operand in an Rvalue or terminator.
142#[derive(Debug, Clone, PartialEq, Eq, Hash)]
143pub enum Operand {
144    /// Copy the value from a place (for Copy types).
145    Copy(Place),
146    /// Move the value from a place (invalidates the source).
147    Move(Place),
148    /// A constant value.
149    Constant(MirConstant),
150}
151
152impl fmt::Display for Operand {
153    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
154        match self {
155            Operand::Copy(p) => write!(f, "copy {}", p),
156            Operand::Move(p) => write!(f, "move {}", p),
157            Operand::Constant(c) => write!(f, "{}", c),
158        }
159    }
160}
161
/// A constant value appearing in MIR.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum MirConstant {
    /// Signed integer constant.
    Int(i64),
    /// Boolean constant.
    Bool(bool),
    /// The `none` value.
    None,
    /// Index of an interned string.
    StringId(u32),
    /// Floating-point constant, stored as its raw bits so the enum can derive `Eq` and `Hash`.
    Float(u64),
    /// Reference to a function by name.
    Function(String),
}

impl fmt::Display for MirConstant {
    /// Renders the constant: numbers and booleans as-is, `none`, `str#<id>`, or `fn:<name>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            MirConstant::None => f.write_str("none"),
            MirConstant::Int(value) => write!(f, "{}", value),
            MirConstant::Bool(flag) => write!(f, "{}", flag),
            MirConstant::StringId(id) => write!(f, "str#{}", id),
            MirConstant::Float(bits) => {
                let value = f64::from_bits(bits);
                write!(f, "{}", value)
            }
            MirConstant::Function(ref name) => write!(f, "fn:{}", name),
        }
    }
}

// ── Rvalues ──────────────────────────────────────────────────────────

/// Whether a borrow is shared or exclusive.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum BorrowKind {
    /// Shared (immutable) borrow: `&x`.
    Shared,
    /// Exclusive (mutable) borrow: `&mut x`.
    Exclusive,
}

impl fmt::Display for BorrowKind {
    /// Renders the borrow operator: `&` or `&mut`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let sigil = match self {
            BorrowKind::Shared => "&",
            BorrowKind::Exclusive => "&mut",
        };
        f.write_str(sigil)
    }
}

209/// Right-hand side of an assignment.
210#[derive(Debug, Clone, PartialEq)]
211pub enum Rvalue {
212    /// Use an operand directly.
213    Use(Operand),
214    /// Create a borrow: `&place` or `&mut place`
215    Borrow(BorrowKind, Place),
216    /// Binary operation.
217    BinaryOp(BinOp, Operand, Operand),
218    /// Unary operation.
219    UnaryOp(UnOp, Operand),
220    /// Function call result (arguments passed via terminator).
221    /// This is a placeholder — actual calls use Call terminators.
222    Aggregate(Vec<Operand>),
223    /// Clone of a value (explicit or auto-inferred).
224    Clone(Operand),
225}
226
/// Binary operators available in MIR.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BinOp {
    // Arithmetic.
    Add,
    Sub,
    Mul,
    Div,
    Mod,
    // Comparisons.
    Eq,
    Ne,
    Lt,
    Le,
    Gt,
    Ge,
    // Connectives. NOTE(review): logical vs. bitwise semantics are not visible here.
    And,
    Or,
}

/// Unary operators available in MIR.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum UnOp {
    /// Negation.
    Neg,
    /// `not`. NOTE(review): logical vs. bitwise semantics are not visible here.
    Not,
}

252// ── Statements ───────────────────────────────────────────────────────
253
254/// A statement within a basic block (doesn't affect control flow).
255#[derive(Debug, Clone, PartialEq)]
256pub struct MirStatement {
257    pub kind: StatementKind,
258    pub span: Span,
259    /// The program point of this statement (assigned during linearization).
260    pub point: Point,
261}
262
263#[derive(Debug, Clone, PartialEq)]
264pub enum StatementKind {
265    /// Assign a value to a place: `place = rvalue`
266    Assign(Place, Rvalue),
267    /// Drop a place (scope exit, explicit drop).
268    /// Generates invalidation facts for any loans on this place.
269    Drop(Place),
270    /// No-op (placeholder, padding).
271    Nop,
272}
273
274// ── Terminators ──────────────────────────────────────────────────────
275
276/// A block terminator (controls flow between basic blocks).
277#[derive(Debug, Clone, PartialEq)]
278pub struct Terminator {
279    pub kind: TerminatorKind,
280    pub span: Span,
281}
282
283#[derive(Debug, Clone, PartialEq)]
284pub enum TerminatorKind {
285    /// Unconditional jump.
286    Goto(BasicBlockId),
287    /// Conditional branch.
288    SwitchBool {
289        operand: Operand,
290        true_bb: BasicBlockId,
291        false_bb: BasicBlockId,
292    },
293    /// Function call.
294    Call {
295        func: Operand,
296        args: Vec<Operand>,
297        /// Where to store the return value.
298        destination: Place,
299        /// Block to jump to after the call returns.
300        next: BasicBlockId,
301    },
302    /// Return from function.
303    Return,
304    /// Unreachable (after diverging calls, infinite loops).
305    Unreachable,
306}
307
308// ── Basic Blocks ─────────────────────────────────────────────────────
309
310/// A basic block: a sequence of statements ending in a terminator.
311#[derive(Debug, Clone)]
312pub struct BasicBlock {
313    pub id: BasicBlockId,
314    pub statements: Vec<MirStatement>,
315    pub terminator: Terminator,
316}
317
318// ── MIR Function ─────────────────────────────────────────────────────
319
320/// The MIR representation of a single function.
321#[derive(Debug, Clone)]
322pub struct MirFunction {
323    pub name: String,
324    /// The basic blocks forming the CFG.
325    pub blocks: Vec<BasicBlock>,
326    /// Number of local variable slots.
327    pub num_locals: u16,
328    /// Which locals are function parameters.
329    pub param_slots: Vec<SlotId>,
330    /// Type information for locals (for Copy/Clone inference).
331    pub local_types: Vec<LocalTypeInfo>,
332    /// Source span of the function.
333    pub span: Span,
334}
335
/// Copy/Clone classification of a local variable's type.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum LocalTypeInfo {
    /// Primitive (int, number, bool, none): implicitly Copy, so no borrow tracking is needed.
    Copy,
    /// Heap type (String, Array, TypedObject, etc.) that needs borrow/move/clone tracking.
    NonCopy,
    /// Not known yet; to be resolved during analysis.
    Unknown,
}

347impl MirFunction {
348    /// Get the entry block (always block 0).
349    pub fn entry_block(&self) -> BasicBlockId {
350        BasicBlockId(0)
351    }
352
353    /// Iterate over all blocks.
354    pub fn iter_blocks(&self) -> impl Iterator<Item = &BasicBlock> {
355        self.blocks.iter()
356    }
357
358    /// Get a block by ID.
359    pub fn block(&self, id: BasicBlockId) -> &BasicBlock {
360        &self.blocks[id.0 as usize]
361    }
362
363    /// Linearize all statements into a flat list of points.
364    /// Returns (point, block_id, statement_index) triples.
365    pub fn all_points(&self) -> Vec<(Point, BasicBlockId, usize)> {
366        let mut points = Vec::new();
367        for block in &self.blocks {
368            for (i, stmt) in block.statements.iter().enumerate() {
369                points.push((stmt.point, block.id, i));
370            }
371        }
372        points
373    }
374}
375
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the bare local `_slot`.
    fn local(slot: u16) -> Place {
        Place::Local(SlotId(slot))
    }

    /// Builds the field projection `base.idx`.
    fn field(base: Place, idx: u16) -> Place {
        Place::Field(Box::new(base), FieldIdx(idx))
    }

    #[test]
    fn test_place_root_local() {
        let place = field(local(0), 1);
        assert_eq!(place.root_local(), SlotId(0));
    }

    #[test]
    fn test_place_prefix() {
        let parent = local(0);
        let child = field(local(0), 0);
        assert!(parent.is_prefix_of(&child));
        assert!(!child.is_prefix_of(&parent));
    }

    #[test]
    fn test_disjoint_fields_no_conflict() {
        // Sibling fields of the same local must not be reported as overlapping.
        let first = field(local(0), 0);
        let second = field(local(0), 1);
        assert!(!first.overlaps(&second));
    }

    #[test]
    fn test_same_field_conflicts() {
        let lhs = field(local(0), 0);
        let rhs = field(local(0), 0);
        assert!(lhs.conflicts_with(&rhs));
    }

    #[test]
    fn test_different_locals_no_conflict() {
        assert!(!local(0).conflicts_with(&local(1)));
    }

    #[test]
    fn test_parent_child_conflict() {
        let parent = local(0);
        let child = field(local(0), 0);
        assert!(parent.conflicts_with(&child));
        assert!(child.conflicts_with(&parent));
    }
}