1use std::sync::Arc;
2
3use crate::ir_inner::model::expr::Expr;
4use crate::ir_inner::model::node::Node;
5use crate::ir_inner::model::types::BinOp;
6
7use super::{meta::buffer_decl_canonical_key, BufferDecl, Program};
8
9impl Program {
10 #[must_use]
17 pub fn canonicalized(&self) -> Self {
18 let mut buffers = self.buffers().to_vec();
19 sort_buffers(&mut buffers);
20 let mut ctx = CanonicalCtx::default();
21 self.with_rewritten_entry(ctx.canonicalize_nodes(self.entry()))
22 .with_rewritten_buffers(buffers)
23 }
24
25 #[must_use]
32 pub fn canonical_wire_bytes(&self) -> Result<Vec<u8>, crate::error::Error> {
33 let canonical = self.canonicalized();
34 let stats = canonical.stats();
40 let estimate = 256
41 + stats.node_count.saturating_mul(48)
42 + canonical.buffers().len().saturating_mul(64);
43 let mut out = Vec::with_capacity(estimate);
44 crate::serial::wire::encode::to_wire_into(&canonical, &mut out)
45 .map_err(|message| crate::error::Error::WireFormatValidation { message })?;
46 Ok(out)
47 }
48
49 pub fn canonical_wire_hash(&self) -> Result<blake3::Hash, crate::error::Error> {
56 self.canonical_wire_bytes()
57 .map(|bytes| blake3::hash(&bytes))
58 }
59}
60
61fn sort_buffers(buffers: &mut [BufferDecl]) {
62 buffers.sort_by_cached_key(buffer_decl_canonical_key);
63}
64
/// Scratch state reused while canonicalizing one program.
///
/// The two byte buffers hold the wire encodings of the left and right
/// operands being compared when deciding whether to swap them. Reusing them
/// avoids allocating fresh key vectors for every binary expression.
#[derive(Default)]
struct CanonicalCtx {
    /// Wire-key scratch buffer for the left operand.
    left_key: Vec<u8>,
    /// Wire-key scratch buffer for the right operand.
    right_key: Vec<u8>,
}
70
71impl CanonicalCtx {
72 fn canonicalize_nodes(&mut self, nodes: &[Node]) -> Vec<Node> {
73 let mut out = Vec::with_capacity(nodes.len());
74 for node in nodes {
75 push_canonical_node(&mut out, self.canonicalize_node(node));
76 }
77 out
78 }
79
80 fn canonicalize_node(&mut self, node: &Node) -> Node {
81 match node {
82 Node::Let { name, value } => Node::Let {
83 name: name.clone(),
84 value: self.canonicalize_expr(value),
85 },
86 Node::Assign { name, value } => Node::Assign {
87 name: name.clone(),
88 value: self.canonicalize_expr(value),
89 },
90 Node::Store {
91 buffer,
92 index,
93 value,
94 } => Node::Store {
95 buffer: buffer.clone(),
96 index: self.canonicalize_expr(index),
97 value: self.canonicalize_expr(value),
98 },
99 Node::If {
100 cond,
101 then,
102 otherwise,
103 } => Node::If {
104 cond: self.canonicalize_expr(cond),
105 then: self.canonicalize_nodes(then),
106 otherwise: self.canonicalize_nodes(otherwise),
107 },
108 Node::Loop {
109 var,
110 from,
111 to,
112 body,
113 } => Node::Loop {
114 var: var.clone(),
115 from: self.canonicalize_expr(from),
116 to: self.canonicalize_expr(to),
117 body: self.canonicalize_nodes(body),
118 },
119 Node::Block(children) => Node::Block(self.canonicalize_nodes(children)),
120 Node::Region {
121 generator,
122 source_region,
123 body,
124 } => Node::Region {
125 generator: generator.clone(),
126 source_region: source_region.clone(),
127 body: Arc::new(self.canonicalize_nodes(body)),
128 },
129 Node::AsyncLoad {
130 source,
131 destination,
132 offset,
133 size,
134 tag,
135 } => Node::AsyncLoad {
136 source: source.clone(),
137 destination: destination.clone(),
138 offset: Box::new(self.canonicalize_expr(offset)),
139 size: Box::new(self.canonicalize_expr(size)),
140 tag: tag.clone(),
141 },
142 Node::AsyncStore {
143 source,
144 destination,
145 offset,
146 size,
147 tag,
148 } => Node::AsyncStore {
149 source: source.clone(),
150 destination: destination.clone(),
151 offset: Box::new(self.canonicalize_expr(offset)),
152 size: Box::new(self.canonicalize_expr(size)),
153 tag: tag.clone(),
154 },
155 Node::Trap { address, tag } => Node::Trap {
156 address: Box::new(self.canonicalize_expr(address)),
157 tag: tag.clone(),
158 },
159 Node::IndirectDispatch {
160 count_buffer,
161 count_offset,
162 } => Node::IndirectDispatch {
163 count_buffer: count_buffer.clone(),
164 count_offset: *count_offset,
165 },
166 Node::AllReduce { buffer, op, group } => Node::AllReduce {
167 buffer: buffer.clone(),
168 op: *op,
169 group: *group,
170 },
171 Node::AllGather {
172 input,
173 output,
174 group,
175 } => Node::AllGather {
176 input: input.clone(),
177 output: output.clone(),
178 group: *group,
179 },
180 Node::ReduceScatter {
181 input,
182 output,
183 op,
184 group,
185 } => Node::ReduceScatter {
186 input: input.clone(),
187 output: output.clone(),
188 op: *op,
189 group: *group,
190 },
191 Node::Broadcast {
192 buffer,
193 root,
194 group,
195 } => Node::Broadcast {
196 buffer: buffer.clone(),
197 root: *root,
198 group: *group,
199 },
200 Node::AsyncWait { tag } => Node::AsyncWait { tag: tag.clone() },
201 Node::Resume { tag } => Node::Resume { tag: tag.clone() },
202 Node::Return => Node::Return,
203 Node::Barrier { ordering } => Node::barrier_with_ordering(*ordering),
204 Node::Opaque(extension) => Node::Opaque(Arc::clone(extension)),
205 }
206 }
207
208 fn canonicalize_expr(&mut self, expr: &Expr) -> Expr {
209 match expr {
210 Expr::BinOp { op, left, right } => {
211 let mut left = self.canonicalize_expr(left);
212 let mut right = self.canonicalize_expr(right);
213 if should_swap_operands(*op, &left, &right, &mut self.left_key, &mut self.right_key)
214 {
215 std::mem::swap(&mut left, &mut right);
216 }
217 Expr::BinOp {
218 op: *op,
219 left: Box::new(left),
220 right: Box::new(right),
221 }
222 }
223 Expr::UnOp { op, operand } => Expr::UnOp {
224 op: op.clone(),
225 operand: Box::new(self.canonicalize_expr(operand)),
226 },
227 Expr::Load { buffer, index } => Expr::Load {
228 buffer: buffer.clone(),
229 index: Box::new(self.canonicalize_expr(index)),
230 },
231 Expr::Call { op_id, args } => Expr::Call {
232 op_id: op_id.clone(),
233 args: args.iter().map(|arg| self.canonicalize_expr(arg)).collect(),
234 },
235 Expr::Select {
236 cond,
237 true_val,
238 false_val,
239 } => Expr::Select {
240 cond: Box::new(self.canonicalize_expr(cond)),
241 true_val: Box::new(self.canonicalize_expr(true_val)),
242 false_val: Box::new(self.canonicalize_expr(false_val)),
243 },
244 Expr::Cast { target, value } => Expr::Cast {
245 target: target.clone(),
246 value: Box::new(self.canonicalize_expr(value)),
247 },
248 Expr::Fma { a, b, c } => Expr::Fma {
249 a: Box::new(self.canonicalize_expr(a)),
250 b: Box::new(self.canonicalize_expr(b)),
251 c: Box::new(self.canonicalize_expr(c)),
252 },
253 Expr::Atomic {
254 op,
255 buffer,
256 index,
257 expected,
258 value,
259 ordering,
260 } => Expr::Atomic {
261 op: *op,
262 buffer: buffer.clone(),
263 index: Box::new(self.canonicalize_expr(index)),
264 expected: expected
265 .as_ref()
266 .map(|expr| Box::new(self.canonicalize_expr(expr))),
267 value: Box::new(self.canonicalize_expr(value)),
268 ordering: *ordering,
269 },
270 Expr::SubgroupBallot { cond } => Expr::SubgroupBallot {
271 cond: Box::new(self.canonicalize_expr(cond)),
272 },
273 Expr::SubgroupShuffle { value, lane } => Expr::SubgroupShuffle {
274 value: Box::new(self.canonicalize_expr(value)),
275 lane: Box::new(self.canonicalize_expr(lane)),
276 },
277 Expr::SubgroupAdd { value } => Expr::SubgroupAdd {
278 value: Box::new(self.canonicalize_expr(value)),
279 },
280 other => other.clone(),
281 }
282 }
283}
284
285fn push_canonical_node(out: &mut Vec<Node>, node: Node) {
286 match node {
287 Node::Block(children) if can_splice_block(&children) => out.extend(children),
288 other => out.push(other),
289 }
290}
291
292fn can_splice_block(nodes: &[Node]) -> bool {
293 nodes.iter().all(|node| !matches!(node, Node::Let { .. }))
294}
295
296fn should_swap_operands(
297 op: BinOp,
298 left: &Expr,
299 right: &Expr,
300 left_key: &mut Vec<u8>,
301 right_key: &mut Vec<u8>,
302) -> bool {
303 if !is_commutative_binop(op) {
304 return false;
305 }
306 match (is_literal(left), is_literal(right)) {
307 (true, false) => true,
308 (false, true) => false,
309 (true, true) => {
310 expr_wire_key_cmp(left, right, left_key, right_key).is_gt()
316 }
317 (false, false) => {
318 can_sort_all_operands(op) && expr_wire_key_cmp(left, right, left_key, right_key).is_gt()
319 }
320 }
321}
322
323fn expr_wire_key_cmp(
324 left: &Expr,
325 right: &Expr,
326 left_key: &mut Vec<u8>,
327 right_key: &mut Vec<u8>,
328) -> std::cmp::Ordering {
329 left_key.clear();
330 right_key.clear();
331 append_expr_wire_key(left_key, left);
332 append_expr_wire_key(right_key, right);
333 left_key.as_slice().cmp(right_key.as_slice())
334}
335
336fn append_expr_wire_key(key: &mut Vec<u8>, expr: &Expr) {
337 if let Err(error) = crate::serial::wire::encode::put_expr(key, expr) {
338 key.clear();
339 key.extend_from_slice(b"VYRE-CANONICAL-EXPR-WIRE-ERROR\0");
340 key.extend_from_slice(error.as_bytes());
341 }
342}
343
344fn is_commutative_binop(op: BinOp) -> bool {
345 matches!(
346 op,
347 BinOp::Add
348 | BinOp::WrappingAdd
349 | BinOp::SaturatingAdd
350 | BinOp::Mul
351 | BinOp::SaturatingMul
352 | BinOp::BitAnd
353 | BinOp::BitOr
354 | BinOp::BitXor
355 | BinOp::Eq
356 | BinOp::Ne
357 | BinOp::And
358 | BinOp::Or
359 | BinOp::Min
360 | BinOp::Max
361 | BinOp::AbsDiff
362 )
363}
364
365fn can_sort_all_operands(op: BinOp) -> bool {
366 matches!(
373 op,
374 BinOp::WrappingAdd
375 | BinOp::SaturatingAdd
376 | BinOp::SaturatingMul
377 | BinOp::BitAnd
378 | BinOp::BitOr
379 | BinOp::BitXor
380 | BinOp::Eq
381 | BinOp::Ne
382 | BinOp::And
383 | BinOp::Or
384 | BinOp::AbsDiff
385 )
386}
387
388fn is_literal(expr: &Expr) -> bool {
389 matches!(
390 expr,
391 Expr::LitU32(_) | Expr::LitI32(_) | Expr::LitF32(_) | Expr::LitBool(_)
392 )
393}