Skip to main content

vyre_foundation/ir_inner/model/program/
canonical.rs

1use std::sync::Arc;
2
3use crate::ir_inner::model::expr::Expr;
4use crate::ir_inner::model::node::Node;
5use crate::ir_inner::model::types::BinOp;
6
7use super::{meta::buffer_decl_canonical_key, BufferDecl, Program};
8
9impl Program {
10    /// Return the canonical IR shape used for security-sensitive cache keys.
11    ///
12    /// Canonicalization preserves executable semantics while removing
13    /// authoring-order noise: buffer declarations are sorted by their stable
14    /// wire key, commutative expression operands are normalized, and `Block`
15    /// wrappers that do not own local bindings are flattened.
16    #[must_use]
17    pub fn canonicalized(&self) -> Self {
18        let mut buffers = self.buffers().to_vec();
19        sort_buffers(&mut buffers);
20        let mut ctx = CanonicalCtx::default();
21        self.with_rewritten_entry(ctx.canonicalize_nodes(self.entry()))
22            .with_rewritten_buffers(buffers)
23    }
24
25    /// Serialize the canonical IR shape into stable VIR0 wire bytes.
26    ///
27    /// # Errors
28    ///
29    /// Returns the same wire-format validation errors as [`Self::to_wire`],
30    /// but after canonical normalization has been applied.
31    #[must_use]
32    pub fn canonical_wire_bytes(&self) -> Result<Vec<u8>, crate::error::Error> {
33        let canonical = self.canonicalized();
34        // Pre-size: VIR0 wire encoding lands in the ballpark of ~32
35        // bytes per IR node + a fixed program header. Over-sizing is
36        // free at this stage and avoids the typical 4-7 reallocations
37        // a fresh Vec<u8> would do while the encoder pushes header
38        // tags + buffer table + node tree.
39        let stats = canonical.stats();
40        let estimate = 256
41            + stats.node_count.saturating_mul(48)
42            + canonical.buffers().len().saturating_mul(64);
43        let mut out = Vec::with_capacity(estimate);
44        crate::serial::wire::encode::to_wire_into(&canonical, &mut out)
45            .map_err(|message| crate::error::Error::WireFormatValidation { message })?;
46        Ok(out)
47    }
48
49    /// BLAKE3 digest of [`Self::canonical_wire_bytes`].
50    ///
51    /// # Errors
52    ///
53    /// Returns a wire-format validation error if the canonical program cannot
54    /// be represented by the current VIR0 encoder.
55    pub fn canonical_wire_hash(&self) -> Result<blake3::Hash, crate::error::Error> {
56        self.canonical_wire_bytes()
57            .map(|bytes| blake3::hash(&bytes))
58    }
59}
60
61fn sort_buffers(buffers: &mut [BufferDecl]) {
62    buffers.sort_by_cached_key(buffer_decl_canonical_key);
63}
64
/// Scratch state carried through one canonicalization pass.
///
/// The two byte buffers receive the wire-encoded keys of the left and right
/// operands whenever the operands of a binary operation are compared. They
/// are cleared before each comparison, so keeping them here only serves to
/// reuse their allocations across comparisons.
#[derive(Default)]
struct CanonicalCtx {
    /// Scratch buffer for the left operand's wire key.
    left_key: Vec<u8>,
    /// Scratch buffer for the right operand's wire key.
    right_key: Vec<u8>,
}
70
71impl CanonicalCtx {
72    fn canonicalize_nodes(&mut self, nodes: &[Node]) -> Vec<Node> {
73        let mut out = Vec::with_capacity(nodes.len());
74        for node in nodes {
75            push_canonical_node(&mut out, self.canonicalize_node(node));
76        }
77        out
78    }
79
80    fn canonicalize_node(&mut self, node: &Node) -> Node {
81        match node {
82            Node::Let { name, value } => Node::Let {
83                name: name.clone(),
84                value: self.canonicalize_expr(value),
85            },
86            Node::Assign { name, value } => Node::Assign {
87                name: name.clone(),
88                value: self.canonicalize_expr(value),
89            },
90            Node::Store {
91                buffer,
92                index,
93                value,
94            } => Node::Store {
95                buffer: buffer.clone(),
96                index: self.canonicalize_expr(index),
97                value: self.canonicalize_expr(value),
98            },
99            Node::If {
100                cond,
101                then,
102                otherwise,
103            } => Node::If {
104                cond: self.canonicalize_expr(cond),
105                then: self.canonicalize_nodes(then),
106                otherwise: self.canonicalize_nodes(otherwise),
107            },
108            Node::Loop {
109                var,
110                from,
111                to,
112                body,
113            } => Node::Loop {
114                var: var.clone(),
115                from: self.canonicalize_expr(from),
116                to: self.canonicalize_expr(to),
117                body: self.canonicalize_nodes(body),
118            },
119            Node::Block(children) => Node::Block(self.canonicalize_nodes(children)),
120            Node::Region {
121                generator,
122                source_region,
123                body,
124            } => Node::Region {
125                generator: generator.clone(),
126                source_region: source_region.clone(),
127                body: Arc::new(self.canonicalize_nodes(body)),
128            },
129            Node::AsyncLoad {
130                source,
131                destination,
132                offset,
133                size,
134                tag,
135            } => Node::AsyncLoad {
136                source: source.clone(),
137                destination: destination.clone(),
138                offset: Box::new(self.canonicalize_expr(offset)),
139                size: Box::new(self.canonicalize_expr(size)),
140                tag: tag.clone(),
141            },
142            Node::AsyncStore {
143                source,
144                destination,
145                offset,
146                size,
147                tag,
148            } => Node::AsyncStore {
149                source: source.clone(),
150                destination: destination.clone(),
151                offset: Box::new(self.canonicalize_expr(offset)),
152                size: Box::new(self.canonicalize_expr(size)),
153                tag: tag.clone(),
154            },
155            Node::Trap { address, tag } => Node::Trap {
156                address: Box::new(self.canonicalize_expr(address)),
157                tag: tag.clone(),
158            },
159            Node::IndirectDispatch {
160                count_buffer,
161                count_offset,
162            } => Node::IndirectDispatch {
163                count_buffer: count_buffer.clone(),
164                count_offset: *count_offset,
165            },
166            Node::AllReduce { buffer, op, group } => Node::AllReduce {
167                buffer: buffer.clone(),
168                op: *op,
169                group: *group,
170            },
171            Node::AllGather {
172                input,
173                output,
174                group,
175            } => Node::AllGather {
176                input: input.clone(),
177                output: output.clone(),
178                group: *group,
179            },
180            Node::ReduceScatter {
181                input,
182                output,
183                op,
184                group,
185            } => Node::ReduceScatter {
186                input: input.clone(),
187                output: output.clone(),
188                op: *op,
189                group: *group,
190            },
191            Node::Broadcast {
192                buffer,
193                root,
194                group,
195            } => Node::Broadcast {
196                buffer: buffer.clone(),
197                root: *root,
198                group: *group,
199            },
200            Node::AsyncWait { tag } => Node::AsyncWait { tag: tag.clone() },
201            Node::Resume { tag } => Node::Resume { tag: tag.clone() },
202            Node::Return => Node::Return,
203            Node::Barrier { ordering } => Node::barrier_with_ordering(*ordering),
204            Node::Opaque(extension) => Node::Opaque(Arc::clone(extension)),
205        }
206    }
207
208    fn canonicalize_expr(&mut self, expr: &Expr) -> Expr {
209        match expr {
210            Expr::BinOp { op, left, right } => {
211                let mut left = self.canonicalize_expr(left);
212                let mut right = self.canonicalize_expr(right);
213                if should_swap_operands(*op, &left, &right, &mut self.left_key, &mut self.right_key)
214                {
215                    std::mem::swap(&mut left, &mut right);
216                }
217                Expr::BinOp {
218                    op: *op,
219                    left: Box::new(left),
220                    right: Box::new(right),
221                }
222            }
223            Expr::UnOp { op, operand } => Expr::UnOp {
224                op: op.clone(),
225                operand: Box::new(self.canonicalize_expr(operand)),
226            },
227            Expr::Load { buffer, index } => Expr::Load {
228                buffer: buffer.clone(),
229                index: Box::new(self.canonicalize_expr(index)),
230            },
231            Expr::Call { op_id, args } => Expr::Call {
232                op_id: op_id.clone(),
233                args: args.iter().map(|arg| self.canonicalize_expr(arg)).collect(),
234            },
235            Expr::Select {
236                cond,
237                true_val,
238                false_val,
239            } => Expr::Select {
240                cond: Box::new(self.canonicalize_expr(cond)),
241                true_val: Box::new(self.canonicalize_expr(true_val)),
242                false_val: Box::new(self.canonicalize_expr(false_val)),
243            },
244            Expr::Cast { target, value } => Expr::Cast {
245                target: target.clone(),
246                value: Box::new(self.canonicalize_expr(value)),
247            },
248            Expr::Fma { a, b, c } => Expr::Fma {
249                a: Box::new(self.canonicalize_expr(a)),
250                b: Box::new(self.canonicalize_expr(b)),
251                c: Box::new(self.canonicalize_expr(c)),
252            },
253            Expr::Atomic {
254                op,
255                buffer,
256                index,
257                expected,
258                value,
259                ordering,
260            } => Expr::Atomic {
261                op: *op,
262                buffer: buffer.clone(),
263                index: Box::new(self.canonicalize_expr(index)),
264                expected: expected
265                    .as_ref()
266                    .map(|expr| Box::new(self.canonicalize_expr(expr))),
267                value: Box::new(self.canonicalize_expr(value)),
268                ordering: *ordering,
269            },
270            Expr::SubgroupBallot { cond } => Expr::SubgroupBallot {
271                cond: Box::new(self.canonicalize_expr(cond)),
272            },
273            Expr::SubgroupShuffle { value, lane } => Expr::SubgroupShuffle {
274                value: Box::new(self.canonicalize_expr(value)),
275                lane: Box::new(self.canonicalize_expr(lane)),
276            },
277            Expr::SubgroupAdd { value } => Expr::SubgroupAdd {
278                value: Box::new(self.canonicalize_expr(value)),
279            },
280            other => other.clone(),
281        }
282    }
283}
284
285fn push_canonical_node(out: &mut Vec<Node>, node: Node) {
286    match node {
287        Node::Block(children) if can_splice_block(&children) => out.extend(children),
288        other => out.push(other),
289    }
290}
291
292fn can_splice_block(nodes: &[Node]) -> bool {
293    nodes.iter().all(|node| !matches!(node, Node::Let { .. }))
294}
295
296fn should_swap_operands(
297    op: BinOp,
298    left: &Expr,
299    right: &Expr,
300    left_key: &mut Vec<u8>,
301    right_key: &mut Vec<u8>,
302) -> bool {
303    if !is_commutative_binop(op) {
304        return false;
305    }
306    match (is_literal(left), is_literal(right)) {
307        (true, false) => true,
308        (false, true) => false,
309        (true, true) => {
310            // Both literals: every commutative op is observably-safe
311            // to canonicalize because the literal pair folds to the
312            // same value regardless of order. The float-sensitivity
313            // contract (Add/Mul reassociation changes rounding) only
314            // applies when at least one operand is non-literal.
315            expr_wire_key_cmp(left, right, left_key, right_key).is_gt()
316        }
317        (false, false) => {
318            can_sort_all_operands(op) && expr_wire_key_cmp(left, right, left_key, right_key).is_gt()
319        }
320    }
321}
322
323fn expr_wire_key_cmp(
324    left: &Expr,
325    right: &Expr,
326    left_key: &mut Vec<u8>,
327    right_key: &mut Vec<u8>,
328) -> std::cmp::Ordering {
329    left_key.clear();
330    right_key.clear();
331    append_expr_wire_key(left_key, left);
332    append_expr_wire_key(right_key, right);
333    left_key.as_slice().cmp(right_key.as_slice())
334}
335
336fn append_expr_wire_key(key: &mut Vec<u8>, expr: &Expr) {
337    if let Err(error) = crate::serial::wire::encode::put_expr(key, expr) {
338        key.clear();
339        key.extend_from_slice(b"VYRE-CANONICAL-EXPR-WIRE-ERROR\0");
340        key.extend_from_slice(error.as_bytes());
341    }
342}
343
344fn is_commutative_binop(op: BinOp) -> bool {
345    matches!(
346        op,
347        BinOp::Add
348            | BinOp::WrappingAdd
349            | BinOp::SaturatingAdd
350            | BinOp::Mul
351            | BinOp::SaturatingMul
352            | BinOp::BitAnd
353            | BinOp::BitOr
354            | BinOp::BitXor
355            | BinOp::Eq
356            | BinOp::Ne
357            | BinOp::And
358            | BinOp::Or
359            | BinOp::Min
360            | BinOp::Max
361            | BinOp::AbsDiff
362    )
363}
364
365fn can_sort_all_operands(op: BinOp) -> bool {
366    // Ops whose operand swap is observably safe even when both
367    // operands are arbitrary non-literal expressions. Excludes Add /
368    // Mul because float reassociation changes rounding for non-literal
369    // chains; `should_swap_operands` handles the both-literal case
370    // separately so the canonical fingerprint still normalises
371    // `Add(1, 2)` vs `Add(2, 1)`.
372    matches!(
373        op,
374        BinOp::WrappingAdd
375            | BinOp::SaturatingAdd
376            | BinOp::SaturatingMul
377            | BinOp::BitAnd
378            | BinOp::BitOr
379            | BinOp::BitXor
380            | BinOp::Eq
381            | BinOp::Ne
382            | BinOp::And
383            | BinOp::Or
384            | BinOp::AbsDiff
385    )
386}
387
388fn is_literal(expr: &Expr) -> bool {
389    matches!(
390        expr,
391        Expr::LitU32(_) | Expr::LitI32(_) | Expr::LitF32(_) | Expr::LitBool(_)
392    )
393}