Skip to main content

vyre_foundation/ir_inner/model/program/
canonical.rs

1use std::sync::Arc;
2
3use crate::ir_inner::model::expr::Expr;
4use crate::ir_inner::model::node::Node;
5use crate::ir_inner::model::types::BinOp;
6
7use super::{meta::buffer_decl_canonical_key, BufferDecl, Program};
8
9impl Program {
10    /// Return the canonical IR shape used for security-sensitive cache keys.
11    ///
12    /// Canonicalization preserves executable semantics while removing
13    /// authoring-order noise: buffer declarations are sorted by their stable
14    /// wire key, commutative expression operands are normalized, and `Block`
15    /// wrappers that do not own local bindings are flattened.
16    #[must_use]
17    pub fn canonicalized(&self) -> Self {
18        let mut buffers = self.buffers().to_vec();
19        sort_buffers(&mut buffers);
20        let mut ctx = CanonicalCtx::default();
21        self.with_rewritten_entry(ctx.canonicalize_nodes(self.entry()))
22            .with_rewritten_buffers(buffers)
23    }
24
25    /// Serialize the canonical IR shape into stable VIR0 wire bytes.
26    ///
27    /// # Errors
28    ///
29    /// Returns the same wire-format validation errors as [`Self::to_wire`],
30    /// but after canonical normalization has been applied.
31    #[must_use]
32    pub fn canonical_wire_bytes(&self) -> Result<Vec<u8>, crate::error::Error> {
33        let canonical = self.canonicalized();
34        let mut out = Vec::new();
35        crate::serial::wire::encode::to_wire_into(&canonical, &mut out)
36            .map_err(|message| crate::error::Error::WireFormatValidation { message })?;
37        Ok(out)
38    }
39
40    /// BLAKE3 digest of [`Self::canonical_wire_bytes`].
41    ///
42    /// # Errors
43    ///
44    /// Returns a wire-format validation error if the canonical program cannot
45    /// be represented by the current VIR0 encoder.
46    pub fn canonical_wire_hash(&self) -> Result<blake3::Hash, crate::error::Error> {
47        self.canonical_wire_bytes()
48            .map(|bytes| blake3::hash(&bytes))
49    }
50}
51
52fn sort_buffers(buffers: &mut [BufferDecl]) {
53    buffers.sort_by_cached_key(buffer_decl_canonical_key);
54}
55
/// Scratch state reused across one canonicalization pass.
///
/// The two byte buffers are handed to `should_swap_operands` for every
/// `BinOp`; `expr_wire_key_cmp` clears and refills them on each comparison,
/// so they only exist to reuse allocations and carry no state between calls.
#[derive(Default)]
struct CanonicalCtx {
    /// Scratch bytes holding the wire key of the left operand being compared.
    left_key: Vec<u8>,
    /// Scratch bytes holding the wire key of the right operand being compared.
    right_key: Vec<u8>,
}
61
62impl CanonicalCtx {
63    fn canonicalize_nodes(&mut self, nodes: &[Node]) -> Vec<Node> {
64        let mut out = Vec::with_capacity(nodes.len());
65        for node in nodes {
66            push_canonical_node(&mut out, self.canonicalize_node(node));
67        }
68        out
69    }
70
71    fn canonicalize_node(&mut self, node: &Node) -> Node {
72        match node {
73            Node::Let { name, value } => Node::Let {
74                name: name.clone(),
75                value: self.canonicalize_expr(value),
76            },
77            Node::Assign { name, value } => Node::Assign {
78                name: name.clone(),
79                value: self.canonicalize_expr(value),
80            },
81            Node::Store {
82                buffer,
83                index,
84                value,
85            } => Node::Store {
86                buffer: buffer.clone(),
87                index: self.canonicalize_expr(index),
88                value: self.canonicalize_expr(value),
89            },
90            Node::If {
91                cond,
92                then,
93                otherwise,
94            } => Node::If {
95                cond: self.canonicalize_expr(cond),
96                then: self.canonicalize_nodes(then),
97                otherwise: self.canonicalize_nodes(otherwise),
98            },
99            Node::Loop {
100                var,
101                from,
102                to,
103                body,
104            } => Node::Loop {
105                var: var.clone(),
106                from: self.canonicalize_expr(from),
107                to: self.canonicalize_expr(to),
108                body: self.canonicalize_nodes(body),
109            },
110            Node::Block(children) => Node::Block(self.canonicalize_nodes(children)),
111            Node::Region {
112                generator,
113                source_region,
114                body,
115            } => Node::Region {
116                generator: generator.clone(),
117                source_region: source_region.clone(),
118                body: Arc::new(self.canonicalize_nodes(body)),
119            },
120            Node::AsyncLoad {
121                source,
122                destination,
123                offset,
124                size,
125                tag,
126            } => Node::AsyncLoad {
127                source: source.clone(),
128                destination: destination.clone(),
129                offset: Box::new(self.canonicalize_expr(offset)),
130                size: Box::new(self.canonicalize_expr(size)),
131                tag: tag.clone(),
132            },
133            Node::AsyncStore {
134                source,
135                destination,
136                offset,
137                size,
138                tag,
139            } => Node::AsyncStore {
140                source: source.clone(),
141                destination: destination.clone(),
142                offset: Box::new(self.canonicalize_expr(offset)),
143                size: Box::new(self.canonicalize_expr(size)),
144                tag: tag.clone(),
145            },
146            Node::Trap { address, tag } => Node::Trap {
147                address: Box::new(self.canonicalize_expr(address)),
148                tag: tag.clone(),
149            },
150            Node::IndirectDispatch {
151                count_buffer,
152                count_offset,
153            } => Node::IndirectDispatch {
154                count_buffer: count_buffer.clone(),
155                count_offset: *count_offset,
156            },
157            Node::AsyncWait { tag } => Node::AsyncWait { tag: tag.clone() },
158            Node::Resume { tag } => Node::Resume { tag: tag.clone() },
159            Node::Return => Node::Return,
160            Node::Barrier { ordering } => Node::barrier_with_ordering(*ordering),
161            Node::Opaque(extension) => Node::Opaque(Arc::clone(extension)),
162        }
163    }
164
165    fn canonicalize_expr(&mut self, expr: &Expr) -> Expr {
166        match expr {
167            Expr::BinOp { op, left, right } => {
168                let mut left = self.canonicalize_expr(left);
169                let mut right = self.canonicalize_expr(right);
170                if should_swap_operands(op, &left, &right, &mut self.left_key, &mut self.right_key)
171                {
172                    std::mem::swap(&mut left, &mut right);
173                }
174                Expr::BinOp {
175                    op: *op,
176                    left: Box::new(left),
177                    right: Box::new(right),
178                }
179            }
180            Expr::UnOp { op, operand } => Expr::UnOp {
181                op: op.clone(),
182                operand: Box::new(self.canonicalize_expr(operand)),
183            },
184            Expr::Load { buffer, index } => Expr::Load {
185                buffer: buffer.clone(),
186                index: Box::new(self.canonicalize_expr(index)),
187            },
188            Expr::Call { op_id, args } => Expr::Call {
189                op_id: op_id.clone(),
190                args: args.iter().map(|arg| self.canonicalize_expr(arg)).collect(),
191            },
192            Expr::Select {
193                cond,
194                true_val,
195                false_val,
196            } => Expr::Select {
197                cond: Box::new(self.canonicalize_expr(cond)),
198                true_val: Box::new(self.canonicalize_expr(true_val)),
199                false_val: Box::new(self.canonicalize_expr(false_val)),
200            },
201            Expr::Cast { target, value } => Expr::Cast {
202                target: target.clone(),
203                value: Box::new(self.canonicalize_expr(value)),
204            },
205            Expr::Fma { a, b, c } => Expr::Fma {
206                a: Box::new(self.canonicalize_expr(a)),
207                b: Box::new(self.canonicalize_expr(b)),
208                c: Box::new(self.canonicalize_expr(c)),
209            },
210            Expr::Atomic {
211                op,
212                buffer,
213                index,
214                expected,
215                value,
216                ordering,
217            } => Expr::Atomic {
218                op: *op,
219                buffer: buffer.clone(),
220                index: Box::new(self.canonicalize_expr(index)),
221                expected: expected
222                    .as_ref()
223                    .map(|expr| Box::new(self.canonicalize_expr(expr))),
224                value: Box::new(self.canonicalize_expr(value)),
225                ordering: *ordering,
226            },
227            Expr::SubgroupBallot { cond } => Expr::SubgroupBallot {
228                cond: Box::new(self.canonicalize_expr(cond)),
229            },
230            Expr::SubgroupShuffle { value, lane } => Expr::SubgroupShuffle {
231                value: Box::new(self.canonicalize_expr(value)),
232                lane: Box::new(self.canonicalize_expr(lane)),
233            },
234            Expr::SubgroupAdd { value } => Expr::SubgroupAdd {
235                value: Box::new(self.canonicalize_expr(value)),
236            },
237            other => other.clone(),
238        }
239    }
240}
241
242fn push_canonical_node(out: &mut Vec<Node>, node: Node) {
243    match node {
244        Node::Block(children) if can_splice_block(&children) => out.extend(children),
245        other => out.push(other),
246    }
247}
248
249fn can_splice_block(nodes: &[Node]) -> bool {
250    nodes.iter().all(|node| !matches!(node, Node::Let { .. }))
251}
252
253fn should_swap_operands(
254    op: &BinOp,
255    left: &Expr,
256    right: &Expr,
257    left_key: &mut Vec<u8>,
258    right_key: &mut Vec<u8>,
259) -> bool {
260    if !is_commutative_binop(op) {
261        return false;
262    }
263    match (is_literal(left), is_literal(right)) {
264        (true, false) => true,
265        (false, true) => false,
266        (true, true) => {
267            // Both literals: every commutative op is observably-safe
268            // to canonicalize because the literal pair folds to the
269            // same value regardless of order. The float-sensitivity
270            // contract (Add/Mul reassociation changes rounding) only
271            // applies when at least one operand is non-literal.
272            expr_wire_key_cmp(left, right, left_key, right_key).is_gt()
273        }
274        (false, false) => {
275            can_sort_all_operands(op) && expr_wire_key_cmp(left, right, left_key, right_key).is_gt()
276        }
277    }
278}
279
280fn expr_wire_key_cmp(
281    left: &Expr,
282    right: &Expr,
283    left_key: &mut Vec<u8>,
284    right_key: &mut Vec<u8>,
285) -> std::cmp::Ordering {
286    left_key.clear();
287    right_key.clear();
288    append_expr_wire_key(left_key, left);
289    append_expr_wire_key(right_key, right);
290    left_key.as_slice().cmp(right_key.as_slice())
291}
292
293fn append_expr_wire_key(key: &mut Vec<u8>, expr: &Expr) {
294    if let Err(error) = crate::serial::wire::encode::put_expr(key, expr) {
295        key.clear();
296        key.extend_from_slice(b"VYRE-CANONICAL-EXPR-WIRE-ERROR\0");
297        key.extend_from_slice(error.as_bytes());
298    }
299}
300
301fn is_commutative_binop(op: &BinOp) -> bool {
302    matches!(
303        op,
304        BinOp::Add
305            | BinOp::WrappingAdd
306            | BinOp::SaturatingAdd
307            | BinOp::Mul
308            | BinOp::SaturatingMul
309            | BinOp::BitAnd
310            | BinOp::BitOr
311            | BinOp::BitXor
312            | BinOp::Eq
313            | BinOp::Ne
314            | BinOp::And
315            | BinOp::Or
316            | BinOp::Min
317            | BinOp::Max
318            | BinOp::AbsDiff
319    )
320}
321
322fn can_sort_all_operands(op: &BinOp) -> bool {
323    // Ops whose operand swap is observably safe even when both
324    // operands are arbitrary non-literal expressions. Excludes Add /
325    // Mul because float reassociation changes rounding for non-literal
326    // chains; `should_swap_operands` handles the both-literal case
327    // separately so the canonical fingerprint still normalises
328    // `Add(1, 2)` vs `Add(2, 1)`.
329    matches!(
330        op,
331        BinOp::WrappingAdd
332            | BinOp::SaturatingAdd
333            | BinOp::SaturatingMul
334            | BinOp::BitAnd
335            | BinOp::BitOr
336            | BinOp::BitXor
337            | BinOp::Eq
338            | BinOp::Ne
339            | BinOp::And
340            | BinOp::Or
341            | BinOp::AbsDiff
342    )
343}
344
345fn is_literal(expr: &Expr) -> bool {
346    matches!(
347        expr,
348        Expr::LitU32(_) | Expr::LitI32(_) | Expr::LitF32(_) | Expr::LitBool(_)
349    )
350}