1use std::sync::Arc;
2
3use crate::ir_inner::model::expr::Expr;
4use crate::ir_inner::model::node::Node;
5use crate::ir_inner::model::types::BinOp;
6
7use super::{meta::buffer_decl_canonical_key, BufferDecl, Program};
8
9impl Program {
10 #[must_use]
17 pub fn canonicalized(&self) -> Self {
18 let mut buffers = self.buffers().to_vec();
19 sort_buffers(&mut buffers);
20 let mut ctx = CanonicalCtx::default();
21 self.with_rewritten_entry(ctx.canonicalize_nodes(self.entry()))
22 .with_rewritten_buffers(buffers)
23 }
24
25 #[must_use]
32 pub fn canonical_wire_bytes(&self) -> Result<Vec<u8>, crate::error::Error> {
33 let canonical = self.canonicalized();
34 let mut out = Vec::new();
35 crate::serial::wire::encode::to_wire_into(&canonical, &mut out)
36 .map_err(|message| crate::error::Error::WireFormatValidation { message })?;
37 Ok(out)
38 }
39
40 pub fn canonical_wire_hash(&self) -> Result<blake3::Hash, crate::error::Error> {
47 self.canonical_wire_bytes()
48 .map(|bytes| blake3::hash(&bytes))
49 }
50}
51
52fn sort_buffers(buffers: &mut [BufferDecl]) {
53 buffers.sort_by_cached_key(buffer_decl_canonical_key);
54}
55
56#[derive(Default)]
57struct CanonicalCtx {
58 left_key: Vec<u8>,
59 right_key: Vec<u8>,
60}
61
62impl CanonicalCtx {
63 fn canonicalize_nodes(&mut self, nodes: &[Node]) -> Vec<Node> {
64 let mut out = Vec::with_capacity(nodes.len());
65 for node in nodes {
66 push_canonical_node(&mut out, self.canonicalize_node(node));
67 }
68 out
69 }
70
71 fn canonicalize_node(&mut self, node: &Node) -> Node {
72 match node {
73 Node::Let { name, value } => Node::Let {
74 name: name.clone(),
75 value: self.canonicalize_expr(value),
76 },
77 Node::Assign { name, value } => Node::Assign {
78 name: name.clone(),
79 value: self.canonicalize_expr(value),
80 },
81 Node::Store {
82 buffer,
83 index,
84 value,
85 } => Node::Store {
86 buffer: buffer.clone(),
87 index: self.canonicalize_expr(index),
88 value: self.canonicalize_expr(value),
89 },
90 Node::If {
91 cond,
92 then,
93 otherwise,
94 } => Node::If {
95 cond: self.canonicalize_expr(cond),
96 then: self.canonicalize_nodes(then),
97 otherwise: self.canonicalize_nodes(otherwise),
98 },
99 Node::Loop {
100 var,
101 from,
102 to,
103 body,
104 } => Node::Loop {
105 var: var.clone(),
106 from: self.canonicalize_expr(from),
107 to: self.canonicalize_expr(to),
108 body: self.canonicalize_nodes(body),
109 },
110 Node::Block(children) => Node::Block(self.canonicalize_nodes(children)),
111 Node::Region {
112 generator,
113 source_region,
114 body,
115 } => Node::Region {
116 generator: generator.clone(),
117 source_region: source_region.clone(),
118 body: Arc::new(self.canonicalize_nodes(body)),
119 },
120 Node::AsyncLoad {
121 source,
122 destination,
123 offset,
124 size,
125 tag,
126 } => Node::AsyncLoad {
127 source: source.clone(),
128 destination: destination.clone(),
129 offset: Box::new(self.canonicalize_expr(offset)),
130 size: Box::new(self.canonicalize_expr(size)),
131 tag: tag.clone(),
132 },
133 Node::AsyncStore {
134 source,
135 destination,
136 offset,
137 size,
138 tag,
139 } => Node::AsyncStore {
140 source: source.clone(),
141 destination: destination.clone(),
142 offset: Box::new(self.canonicalize_expr(offset)),
143 size: Box::new(self.canonicalize_expr(size)),
144 tag: tag.clone(),
145 },
146 Node::Trap { address, tag } => Node::Trap {
147 address: Box::new(self.canonicalize_expr(address)),
148 tag: tag.clone(),
149 },
150 Node::IndirectDispatch {
151 count_buffer,
152 count_offset,
153 } => Node::IndirectDispatch {
154 count_buffer: count_buffer.clone(),
155 count_offset: *count_offset,
156 },
157 Node::AsyncWait { tag } => Node::AsyncWait { tag: tag.clone() },
158 Node::Resume { tag } => Node::Resume { tag: tag.clone() },
159 Node::Return => Node::Return,
160 Node::Barrier { ordering } => Node::barrier_with_ordering(*ordering),
161 Node::Opaque(extension) => Node::Opaque(Arc::clone(extension)),
162 }
163 }
164
165 fn canonicalize_expr(&mut self, expr: &Expr) -> Expr {
166 match expr {
167 Expr::BinOp { op, left, right } => {
168 let mut left = self.canonicalize_expr(left);
169 let mut right = self.canonicalize_expr(right);
170 if should_swap_operands(op, &left, &right, &mut self.left_key, &mut self.right_key)
171 {
172 std::mem::swap(&mut left, &mut right);
173 }
174 Expr::BinOp {
175 op: *op,
176 left: Box::new(left),
177 right: Box::new(right),
178 }
179 }
180 Expr::UnOp { op, operand } => Expr::UnOp {
181 op: op.clone(),
182 operand: Box::new(self.canonicalize_expr(operand)),
183 },
184 Expr::Load { buffer, index } => Expr::Load {
185 buffer: buffer.clone(),
186 index: Box::new(self.canonicalize_expr(index)),
187 },
188 Expr::Call { op_id, args } => Expr::Call {
189 op_id: op_id.clone(),
190 args: args.iter().map(|arg| self.canonicalize_expr(arg)).collect(),
191 },
192 Expr::Select {
193 cond,
194 true_val,
195 false_val,
196 } => Expr::Select {
197 cond: Box::new(self.canonicalize_expr(cond)),
198 true_val: Box::new(self.canonicalize_expr(true_val)),
199 false_val: Box::new(self.canonicalize_expr(false_val)),
200 },
201 Expr::Cast { target, value } => Expr::Cast {
202 target: target.clone(),
203 value: Box::new(self.canonicalize_expr(value)),
204 },
205 Expr::Fma { a, b, c } => Expr::Fma {
206 a: Box::new(self.canonicalize_expr(a)),
207 b: Box::new(self.canonicalize_expr(b)),
208 c: Box::new(self.canonicalize_expr(c)),
209 },
210 Expr::Atomic {
211 op,
212 buffer,
213 index,
214 expected,
215 value,
216 ordering,
217 } => Expr::Atomic {
218 op: *op,
219 buffer: buffer.clone(),
220 index: Box::new(self.canonicalize_expr(index)),
221 expected: expected
222 .as_ref()
223 .map(|expr| Box::new(self.canonicalize_expr(expr))),
224 value: Box::new(self.canonicalize_expr(value)),
225 ordering: *ordering,
226 },
227 Expr::SubgroupBallot { cond } => Expr::SubgroupBallot {
228 cond: Box::new(self.canonicalize_expr(cond)),
229 },
230 Expr::SubgroupShuffle { value, lane } => Expr::SubgroupShuffle {
231 value: Box::new(self.canonicalize_expr(value)),
232 lane: Box::new(self.canonicalize_expr(lane)),
233 },
234 Expr::SubgroupAdd { value } => Expr::SubgroupAdd {
235 value: Box::new(self.canonicalize_expr(value)),
236 },
237 other => other.clone(),
238 }
239 }
240}
241
242fn push_canonical_node(out: &mut Vec<Node>, node: Node) {
243 match node {
244 Node::Block(children) if can_splice_block(&children) => out.extend(children),
245 other => out.push(other),
246 }
247}
248
249fn can_splice_block(nodes: &[Node]) -> bool {
250 nodes.iter().all(|node| !matches!(node, Node::Let { .. }))
251}
252
253fn should_swap_operands(
254 op: &BinOp,
255 left: &Expr,
256 right: &Expr,
257 left_key: &mut Vec<u8>,
258 right_key: &mut Vec<u8>,
259) -> bool {
260 if !is_commutative_binop(op) {
261 return false;
262 }
263 match (is_literal(left), is_literal(right)) {
264 (true, false) => true,
265 (false, true) => false,
266 (true, true) => {
267 expr_wire_key_cmp(left, right, left_key, right_key).is_gt()
273 }
274 (false, false) => {
275 can_sort_all_operands(op) && expr_wire_key_cmp(left, right, left_key, right_key).is_gt()
276 }
277 }
278}
279
280fn expr_wire_key_cmp(
281 left: &Expr,
282 right: &Expr,
283 left_key: &mut Vec<u8>,
284 right_key: &mut Vec<u8>,
285) -> std::cmp::Ordering {
286 left_key.clear();
287 right_key.clear();
288 append_expr_wire_key(left_key, left);
289 append_expr_wire_key(right_key, right);
290 left_key.as_slice().cmp(right_key.as_slice())
291}
292
293fn append_expr_wire_key(key: &mut Vec<u8>, expr: &Expr) {
294 if let Err(error) = crate::serial::wire::encode::put_expr(key, expr) {
295 key.clear();
296 key.extend_from_slice(b"VYRE-CANONICAL-EXPR-WIRE-ERROR\0");
297 key.extend_from_slice(error.as_bytes());
298 }
299}
300
301fn is_commutative_binop(op: &BinOp) -> bool {
302 matches!(
303 op,
304 BinOp::Add
305 | BinOp::WrappingAdd
306 | BinOp::SaturatingAdd
307 | BinOp::Mul
308 | BinOp::SaturatingMul
309 | BinOp::BitAnd
310 | BinOp::BitOr
311 | BinOp::BitXor
312 | BinOp::Eq
313 | BinOp::Ne
314 | BinOp::And
315 | BinOp::Or
316 | BinOp::Min
317 | BinOp::Max
318 | BinOp::AbsDiff
319 )
320}
321
322fn can_sort_all_operands(op: &BinOp) -> bool {
323 matches!(
330 op,
331 BinOp::WrappingAdd
332 | BinOp::SaturatingAdd
333 | BinOp::SaturatingMul
334 | BinOp::BitAnd
335 | BinOp::BitOr
336 | BinOp::BitXor
337 | BinOp::Eq
338 | BinOp::Ne
339 | BinOp::And
340 | BinOp::Or
341 | BinOp::AbsDiff
342 )
343}
344
345fn is_literal(expr: &Expr) -> bool {
346 matches!(
347 expr,
348 Expr::LitU32(_) | Expr::LitI32(_) | Expr::LitF32(_) | Expr::LitBool(_)
349 )
350}