xdy/optimizer/commutation.rs
1//! # Commutation
2//!
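//! Commutative operations can reorder their operands without changing the
//! meaning or result. Both addition and multiplication are commutative, so
//! reordering the operands can create additional opportunities for folding
//! constants and reducing strength. This pass gathers chains of the same
//! commutative operation into groups and redistributes the operands within
//! each group so that immediates appear before registers. Subtraction and
//! division are not commutative, but a chain such as `(a - b) - c` can be
//! rewritten as `a - (b + c)` (and likewise `(a / b) / c` as `a / (b * c)`),
//! so the pass treats such chains as "nearly commutative" groups. The
//! constant commuter requires the function to be in static single assignment
//! (SSA) form.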
10
11use std::{
12 cmp::{max, min},
13 collections::{BTreeSet, HashMap}
14};
15
16use crate::{
17 Add, AddressingMode, CanVisitInstructions, DependencyAnalyzer, Div,
18 DropHighest, DropLowest, Exp, Instruction, InstructionVisitor, Mod, Mul,
19 Neg, Return, RollCustomDice, RollRange, RollStandardDice, Sub,
20 SumRollingRecord
21};
22
23////////////////////////////////////////////////////////////////////////////////
24// Commutation. //
25////////////////////////////////////////////////////////////////////////////////
26
/// A commuter that shuffles the immediate operands of commutative instructions
/// to the front in order to simplify constant folding and strength reduction.
/// The standard optimizer applies this pass to a function body before folding
/// constants or reducing strength.
///
/// A commuter is built and driven by [`ConstantCommuter::commute`], which
/// answers the rewritten function body.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConstantCommuter<'inst>
{
	/// The instructions to organize, i.e., the original function body. Reader
	/// program counters reported by the analyzer index into this slice.
	instructions: &'inst [Instruction],

	/// The dependency analyzer, consulted to find the readers of each
	/// instruction's destination.
	analyzer: DependencyAnalyzer<'inst>,

	/// An arrangement of the instructions of a function body into commutative
	/// groups. Each organized instruction is branded with a group number, and
	/// all instructions sharing the same group number are commutative with
	/// respect to each other. Keys are the original (not yet replaced)
	/// instructions.
	groups: HashMap<Instruction, usize>,

	/// The next group number to assign.
	next_group: usize,

	/// Replacement instructions, indexed by original instructions.
	replacements: HashMap<Instruction, Instruction>
}
52
53impl<'inst> ConstantCommuter<'inst>
54{
55 /// Construct a new instruction organizer.
56 ///
57 /// # Parameters
58 /// - `instructions`: The instructions to organize.
59 ///
60 /// # Returns
61 /// The new instruction organizer.
62 pub fn commute(instructions: &'inst [Instruction]) -> Vec<Instruction>
63 {
64 let mut commuter = Self {
65 instructions,
66 analyzer: DependencyAnalyzer::analyze(instructions),
67 groups: HashMap::new(),
68 next_group: 0,
69 replacements: HashMap::new()
70 };
71 // Visit each instruction in the function body to organize them into
72 // commutative groups.
73 for instruction in instructions
74 {
75 instruction.visit(&mut commuter).unwrap();
76 }
77 // For each commutative group, rewrite the instructions to shuffle their
78 // immediate operands to the front of the group. Use the first
79 // instruction of a group to classify the type of the group.
80 let groups = commuter.groups.values().copied().collect::<BTreeSet<_>>();
81 for group in groups
82 {
83 let representative = commuter
84 .instructions
85 .iter()
86 .find(|inst| commuter.groups.get(inst) == Some(&group))
87 .unwrap();
88 match representative
89 {
90 Instruction::DropLowest(_) => commuter
91 .rewrite_commutative_group(group, |dest, srcs| {
92 DropLowest {
93 dest: dest.try_into().unwrap(),
94 count: srcs[0]
95 }
96 }),
97 Instruction::DropHighest(_) => commuter
98 .rewrite_commutative_group(group, |dest, srcs| {
99 DropHighest {
100 dest: dest.try_into().unwrap(),
101 count: srcs[0]
102 }
103 }),
104 Instruction::Add(_) =>
105 {
106 commuter.rewrite_commutative_group(group, |dest, srcs| {
107 Add {
108 dest: dest.try_into().unwrap(),
109 op1: srcs[0],
110 op2: srcs[1]
111 }
112 })
113 },
114 Instruction::Sub(_) => commuter
115 .rewrite_nearly_commutative_group(
116 group,
117 |dest, srcs| Add {
118 dest: dest.try_into().unwrap(),
119 op1: srcs[0],
120 op2: srcs[1]
121 },
122 |dest, srcs| Sub {
123 dest: dest.try_into().unwrap(),
124 op1: srcs[0],
125 op2: srcs[1]
126 }
127 ),
128 Instruction::Mul(_) =>
129 {
130 commuter.rewrite_commutative_group(group, |dest, srcs| {
131 Mul {
132 dest: dest.try_into().unwrap(),
133 op1: srcs[0],
134 op2: srcs[1]
135 }
136 })
137 },
138 Instruction::Div(_) => commuter
139 .rewrite_nearly_commutative_group(
140 group,
141 |dest, srcs| Mul {
142 dest: dest.try_into().unwrap(),
143 op1: srcs[0],
144 op2: srcs[1]
145 },
146 |dest, srcs| Div {
147 dest: dest.try_into().unwrap(),
148 op1: srcs[0],
149 op2: srcs[1]
150 }
151 ),
152 _ =>
153 {}
154 }
155 }
156 // Answer the replacement function body.
157 instructions
158 .iter()
159 .map(|inst| commuter.replacements.get(inst).unwrap_or(inst))
160 .cloned()
161 .collect()
162 }
163}
164
165impl InstructionVisitor<()> for ConstantCommuter<'_>
166{
167 fn visit_roll_range(&mut self, _inst: &RollRange) -> Result<(), ()>
168 {
169 Ok(())
170 }
171
172 fn visit_roll_standard_dice(
173 &mut self,
174 _inst: &RollStandardDice
175 ) -> Result<(), ()>
176 {
177 Ok(())
178 }
179
180 fn visit_roll_custom_dice(
181 &mut self,
182 _inst: &RollCustomDice
183 ) -> Result<(), ()>
184 {
185 Ok(())
186 }
187
188 fn visit_drop_lowest(&mut self, _inst: &DropLowest) -> Result<(), ()>
189 {
190 Ok(())
191 }
192
193 fn visit_drop_highest(&mut self, _inst: &DropHighest) -> Result<(), ()>
194 {
195 Ok(())
196 }
197
198 fn visit_sum_rolling_record(
199 &mut self,
200 _inst: &SumRollingRecord
201 ) -> Result<(), ()>
202 {
203 Ok(())
204 }
205
206 fn visit_add(&mut self, inst: &Add) -> Result<(), ()>
207 {
208 self.organize(*inst, |inst| matches!(inst, Instruction::Add(_)))
209 }
210
211 fn visit_sub(&mut self, inst: &Sub) -> Result<(), ()>
212 {
213 // Subtraction is not commutative, but a chain of subtractions can be
214 // rewritten as a single subtraction with the leading operand and the
215 // sum of the subtrahends.
216 self.organize(*inst, |inst| matches!(inst, Instruction::Sub(_)))
217 }
218
219 fn visit_mul(&mut self, inst: &Mul) -> Result<(), ()>
220 {
221 self.organize(*inst, |inst| matches!(inst, Instruction::Mul(_)))
222 }
223
224 fn visit_div(&mut self, inst: &Div) -> Result<(), ()>
225 {
226 // Division is not commutative, but a chain of divisions can be
227 // rewritten as a single division with the leading operand and the
228 // product of the divisors.
229 self.organize(*inst, |inst| matches!(inst, Instruction::Div(_)))
230 }
231
232 fn visit_mod(&mut self, _inst: &Mod) -> Result<(), ()> { Ok(()) }
233
234 fn visit_exp(&mut self, _inst: &Exp) -> Result<(), ()> { Ok(()) }
235
236 fn visit_neg(&mut self, _inst: &Neg) -> Result<(), ()> { Ok(()) }
237
238 fn visit_return(&mut self, _inst: &Return) -> Result<(), ()> { Ok(()) }
239}
240
241impl ConstantCommuter<'_>
242{
243 /// Place the specified commutative instruction into a group, using the
244 /// supplied filter to recognize other instructions of the same type.
245 ///
246 /// # Parameters
247 /// - `inst`: The instruction to organize.
248 /// - `filter`: A function that answers `true` if the instruction is of the
249 /// same type as the one being organized.
250 fn organize(
251 &mut self,
252 inst: impl Into<Instruction>,
253 filter: impl Fn(&Instruction) -> bool
254 ) -> Result<(), ()>
255 {
256 let inst = inst.into();
257 // Obtain our group, creating it if necessary.
258 let group = self.group_id(&inst);
259 // Merge each reader's group with ours if the reader is the same type of
260 // instruction as us. This reduces the total number of expensive merges.
261 for reader in self
262 .analyzer
263 .readers()
264 .get(&inst.destination().unwrap())
265 .unwrap()
266 .clone()
267 {
268 let reader = &self.instructions[reader.0];
269 if filter(reader)
270 {
271 let reader_group = self.group_id(&reader.clone());
272 self.merge_group_ids(group, reader_group);
273 }
274 }
275 Ok(())
276 }
277
278 /// Answer the identifier of the group to which the specified instruction
279 /// belongs, creating a new group and branding the instruction if
280 /// necessary.
281 ///
282 /// # Parameters
283 /// - `inst`: The instruction.
284 ///
285 /// # Returns
286 /// The requested group identifier.
287 fn group_id(&mut self, inst: &Instruction) -> usize
288 {
289 match self.groups.get(inst)
290 {
291 Some(index) => *index,
292 None =>
293 {
294 let index = self.next_group;
295 self.groups.insert(inst.clone(), index);
296 self.next_group += 1;
297 index
298 }
299 }
300 }
301
302 /// Merge two groups of commutative instructions.
303 ///
304 /// # Parameters
305 /// - `first`: The first group identifier.
306 /// - `second`: The second group identifier.
307 fn merge_group_ids(&mut self, first: usize, second: usize)
308 {
309 if first != second
310 {
311 // We prefer to keep the lower group number.
312 let min = min(first, second);
313 let max = max(first, second);
314 // For each instruction in the second group, move it to the first
315 // group.
316 self.groups.iter_mut().for_each(|(_, group)| {
317 if *group == max
318 {
319 *group = min;
320 }
321 });
322 assert!(self.groups.values().all(|&group| group != max));
323 }
324 }
325
326 /// Answer the instructions in the specified group, sorted by destination
327 /// register.
328 ///
329 /// # Parameters
330 /// - `group`: The group identifier.
331 ///
332 /// # Returns
333 /// The instruction group.
334 fn group(&self, group: usize) -> Vec<Instruction>
335 {
336 let mut instructions = self
337 .groups
338 .iter()
339 .filter_map(|(inst, g)| if *g == group { Some(inst) } else { None })
340 .map(|inst| self.replacements.get(inst).unwrap_or(inst).clone())
341 .collect::<Vec<_>>();
342 instructions.sort_by_key(Instruction::destination);
343 instructions
344 }
345
346 /// Rewrite the instructions in a commutative group such that the immediate
347 /// operands are read first and all registers are read in ascending order.
348 /// Hoisting the immediate operands to the front of a chain of commutative
349 /// instructions makes them more amenable to other optimizations. Do not
350 /// emit the instructions, just populate the replacement map.
351 ///
352 /// # Type Parameters
353 /// - `I`: The type of instruction to rewrite.
354 ///
355 /// # Parameters
356 /// - `group`: The group of instructions to rewrite.
357 /// - `constructor`: A function that constructs a new instruction of the
358 /// appropriate type from the destination and source operands of the
359 /// original instruction.
360 fn rewrite_commutative_group<I>(
361 &mut self,
362 group: usize,
363 constructor: impl Fn(AddressingMode, &[AddressingMode]) -> I
364 ) where
365 I: Into<Instruction>
366 {
367 // Collect all instructions within the specified commutative group,
368 // using any replacements that have already been established. Extract
369 // their operands and sort them in descending order, such that
370 // immediates are at the front and registers are at the back. Ensure
371 // that registers are written before they are read by preserving
372 // ascending order. Vectors can only efficiently pop from the back, so
373 // we reverse the order of the operands before iterating over them.
374 let instructions = self.group(group);
375 let mut ops = instructions
376 .iter()
377 .flat_map(Instruction::sources)
378 .collect::<Vec<_>>();
379 ops.sort();
380 ops.reverse();
381 let arity = ops.len() / instructions.len();
382 instructions.iter().for_each(|inst| {
383 let ops =
384 (0..arity).map(|_| ops.pop().unwrap()).collect::<Vec<_>>();
385 let new_inst =
386 constructor(inst.destination().unwrap(), &ops).into();
387 self.replacements.insert((*inst).clone(), new_inst);
388 });
389 }
390
391 /// Rewrite the instructions in a nearly commutative group such that the
392 /// commutative immediate operands are read first and all commutative
393 /// registers are read in ascending order. Hoisting the immediate operands
394 /// to the front of a chain of nearly commutative instructions makes them
395 /// more amenable to other optimizations. Do not emit the instructions, just
396 /// populate the replacement map. The first instruction in the group does
397 /// not commute with the rest, so we handle it specially and emit it last.
398 ///
399 /// # Type Parameters
400 /// - `C`: The type of commutative instruction to rewrite.
401 /// - `F`: The type of final instruction to rewrite.
402 ///
403 /// # Parameters
404 /// - `group`: The group of instructions to rewrite.
405 /// - `commutative_constructor`: A function that constructs a new
406 /// commutative instruction of the appropriate type from the destination
407 /// and source operands of the original instruction.
408 /// - `final_constructor`: A function that constructs a new terminal
409 /// instruction of the appropriate type from the destination and source
410 /// operands of the original instruction.
411 fn rewrite_nearly_commutative_group<C, F>(
412 &mut self,
413 group: usize,
414 commutative_constructor: impl Fn(AddressingMode, &[AddressingMode]) -> C,
415 final_constructor: impl Fn(AddressingMode, &[AddressingMode]) -> F
416 ) where
417 C: Into<Instruction>,
418 F: Into<Instruction>
419 {
420 // Collect all instructions within the nearly commutative group. If
421 // there's only a single element in the group, then there's nothing to
422 // do.
423 let mut instructions = self.group(group);
424 if instructions.len() > 1
425 {
426 // Extract all of the destination registers from the complete chain,
427 // including the terminal instruction. They are already sorted.
428 let mut targets = instructions
429 .iter()
430 .flat_map(Instruction::destination)
431 .collect::<Vec<_>>();
432 // Extract all of the operands from the complete chain.
433 let mut ops = instructions
434 .iter()
435 .flat_map(Instruction::sources)
436 .collect::<Vec<_>>();
437 // The very first operand is not commutative with the rest, so we
438 // handle it specially. We'll pop it off the front of the list and
439 // re-inject it when we emit the final non-commutative instruction.
440 let first_op = ops.remove(0);
441 // The next to last operand is the destination of the last
442 // commutative instruction. We swap it with the last operand so that
443 // we can pop it off the back of the list. We sort the remaining
444 // operands, which are all part of the commutative chain.
445 let len = ops.len();
446 ops.swap(len - 2, len - 1);
447 let commutative_result = ops.pop().unwrap();
448 ops.sort();
449 // The first instruction is non-commutative, and needs to be emitted
450 // last. We remove it from the list of instructions, rewrite it with
451 // the correct operands, and save it to emit at the end.
452 let first_inst = instructions.remove(0);
453 let terminal = final_constructor(
454 targets.pop().unwrap(),
455 &[first_op, commutative_result]
456 )
457 .into();
458 // Reverse the remaining operands and destinations prior to
459 // traversing the commutative instructions.
460 ops.reverse();
461 targets.reverse();
462 // Rewrite the commutative instructions in the chain, now that
463 // everything is in the correct order for further optimization.
464 let arity = ops.len() / instructions.len();
465 let mut previous = first_inst.clone();
466 instructions.iter().for_each(|inst| {
467 let dest = targets.pop().unwrap();
468 let ops =
469 (0..arity).map(|_| ops.pop().unwrap()).collect::<Vec<_>>();
470 let new_inst = commutative_constructor(dest, &ops).into();
471 self.replacements.insert(previous.clone(), new_inst);
472 previous = inst.clone();
473 });
474 self.replacements.insert(previous, terminal);
475 }
476 }
477}