xdy/optimizer/
commutation.rs

//! # Commutation
//!
//! Commutative operations can reorder their operands without changing the
//! meaning or result. Both addition and multiplication are commutative, so
//! reordering the operands can create additional opportunities for folding
//! constants and reducing strength. This pass reorders the operands of
//! commutative operations so that immediates appear before registers. The
//! constant commuter requires the function to be in static single assignment
//! (SSA) form.
10
11use std::{
12	cmp::{max, min},
13	collections::{BTreeSet, HashMap}
14};
15
16use crate::{
17	Add, AddressingMode, CanVisitInstructions, DependencyAnalyzer, Div,
18	DropHighest, DropLowest, Exp, Instruction, InstructionVisitor, Mod, Mul,
19	Neg, Return, RollCustomDice, RollRange, RollStandardDice, Sub,
20	SumRollingRecord
21};
22
23////////////////////////////////////////////////////////////////////////////////
24//                                Commutation.                                //
25////////////////////////////////////////////////////////////////////////////////
26
27/// A commuter that shuffles the immediate operands of commutative instructions
28/// to the front in order to simplify constant folding and strength reduction.
29/// The standard optimizer applies this pass to a [function](Function) before
30/// folding constants or reducing strength.
31#[derive(Debug, Clone, PartialEq, Eq)]
32pub struct ConstantCommuter<'inst>
33{
34	/// The instructions to organize.
35	instructions: &'inst [Instruction],
36
37	/// The dependency analyzer.
38	analyzer: DependencyAnalyzer<'inst>,
39
40	/// An arrangement of the instructions of a function body into commutative
41	/// groups. Each instruction is branded with a group number, and all
42	/// instructions sharing the same group number are commutative with respect
43	/// to each other.
44	groups: HashMap<Instruction, usize>,
45
46	/// The next group number to assign.
47	next_group: usize,
48
49	/// Replacement instructions, indexed by original instructions.
50	replacements: HashMap<Instruction, Instruction>
51}
52
53impl<'inst> ConstantCommuter<'inst>
54{
55	/// Construct a new instruction organizer.
56	///
57	/// # Parameters
58	/// - `instructions`: The instructions to organize.
59	///
60	/// # Returns
61	/// The new instruction organizer.
62	pub fn commute(instructions: &'inst [Instruction]) -> Vec<Instruction>
63	{
64		let mut commuter = Self {
65			instructions,
66			analyzer: DependencyAnalyzer::analyze(instructions),
67			groups: HashMap::new(),
68			next_group: 0,
69			replacements: HashMap::new()
70		};
71		// Visit each instruction in the function body to organize them into
72		// commutative groups.
73		for instruction in instructions
74		{
75			instruction.visit(&mut commuter).unwrap();
76		}
77		// For each commutative group, rewrite the instructions to shuffle their
78		// immediate operands to the front of the group. Use the first
79		// instruction of a group to classify the type of the group.
80		let groups = commuter.groups.values().copied().collect::<BTreeSet<_>>();
81		for group in groups
82		{
83			let representative = commuter
84				.instructions
85				.iter()
86				.find(|inst| commuter.groups.get(inst) == Some(&group))
87				.unwrap();
88			match representative
89			{
90				Instruction::DropLowest(_) => commuter
91					.rewrite_commutative_group(group, |dest, srcs| {
92						DropLowest {
93							dest: dest.try_into().unwrap(),
94							count: srcs[0]
95						}
96					}),
97				Instruction::DropHighest(_) => commuter
98					.rewrite_commutative_group(group, |dest, srcs| {
99						DropHighest {
100							dest: dest.try_into().unwrap(),
101							count: srcs[0]
102						}
103					}),
104				Instruction::Add(_) =>
105				{
106					commuter.rewrite_commutative_group(group, |dest, srcs| {
107						Add {
108							dest: dest.try_into().unwrap(),
109							op1: srcs[0],
110							op2: srcs[1]
111						}
112					})
113				},
114				Instruction::Sub(_) => commuter
115					.rewrite_nearly_commutative_group(
116						group,
117						|dest, srcs| Add {
118							dest: dest.try_into().unwrap(),
119							op1: srcs[0],
120							op2: srcs[1]
121						},
122						|dest, srcs| Sub {
123							dest: dest.try_into().unwrap(),
124							op1: srcs[0],
125							op2: srcs[1]
126						}
127					),
128				Instruction::Mul(_) =>
129				{
130					commuter.rewrite_commutative_group(group, |dest, srcs| {
131						Mul {
132							dest: dest.try_into().unwrap(),
133							op1: srcs[0],
134							op2: srcs[1]
135						}
136					})
137				},
138				Instruction::Div(_) => commuter
139					.rewrite_nearly_commutative_group(
140						group,
141						|dest, srcs| Mul {
142							dest: dest.try_into().unwrap(),
143							op1: srcs[0],
144							op2: srcs[1]
145						},
146						|dest, srcs| Div {
147							dest: dest.try_into().unwrap(),
148							op1: srcs[0],
149							op2: srcs[1]
150						}
151					),
152				_ =>
153				{}
154			}
155		}
156		// Answer the replacement function body.
157		instructions
158			.iter()
159			.map(|inst| commuter.replacements.get(inst).unwrap_or(inst))
160			.cloned()
161			.collect()
162	}
163}
164
165impl InstructionVisitor<()> for ConstantCommuter<'_>
166{
167	fn visit_roll_range(&mut self, _inst: &RollRange) -> Result<(), ()>
168	{
169		Ok(())
170	}
171
172	fn visit_roll_standard_dice(
173		&mut self,
174		_inst: &RollStandardDice
175	) -> Result<(), ()>
176	{
177		Ok(())
178	}
179
180	fn visit_roll_custom_dice(
181		&mut self,
182		_inst: &RollCustomDice
183	) -> Result<(), ()>
184	{
185		Ok(())
186	}
187
188	fn visit_drop_lowest(&mut self, _inst: &DropLowest) -> Result<(), ()>
189	{
190		Ok(())
191	}
192
193	fn visit_drop_highest(&mut self, _inst: &DropHighest) -> Result<(), ()>
194	{
195		Ok(())
196	}
197
198	fn visit_sum_rolling_record(
199		&mut self,
200		_inst: &SumRollingRecord
201	) -> Result<(), ()>
202	{
203		Ok(())
204	}
205
206	fn visit_add(&mut self, inst: &Add) -> Result<(), ()>
207	{
208		self.organize(*inst, |inst| matches!(inst, Instruction::Add(_)))
209	}
210
211	fn visit_sub(&mut self, inst: &Sub) -> Result<(), ()>
212	{
213		// Subtraction is not commutative, but a chain of subtractions can be
214		// rewritten as a single subtraction with the leading operand and the
215		// sum of the subtrahends.
216		self.organize(*inst, |inst| matches!(inst, Instruction::Sub(_)))
217	}
218
219	fn visit_mul(&mut self, inst: &Mul) -> Result<(), ()>
220	{
221		self.organize(*inst, |inst| matches!(inst, Instruction::Mul(_)))
222	}
223
224	fn visit_div(&mut self, inst: &Div) -> Result<(), ()>
225	{
226		// Division is not commutative, but a chain of divisions can be
227		// rewritten as a single division with the leading operand and the
228		// product of the divisors.
229		self.organize(*inst, |inst| matches!(inst, Instruction::Div(_)))
230	}
231
232	fn visit_mod(&mut self, _inst: &Mod) -> Result<(), ()> { Ok(()) }
233
234	fn visit_exp(&mut self, _inst: &Exp) -> Result<(), ()> { Ok(()) }
235
236	fn visit_neg(&mut self, _inst: &Neg) -> Result<(), ()> { Ok(()) }
237
238	fn visit_return(&mut self, _inst: &Return) -> Result<(), ()> { Ok(()) }
239}
240
241impl ConstantCommuter<'_>
242{
243	/// Place the specified commutative instruction into a group, using the
244	/// supplied filter to recognize other instructions of the same type.
245	///
246	/// # Parameters
247	/// - `inst`: The instruction to organize.
248	/// - `filter`: A function that answers `true` if the instruction is of the
249	///   same type as the one being organized.
250	fn organize(
251		&mut self,
252		inst: impl Into<Instruction>,
253		filter: impl Fn(&Instruction) -> bool
254	) -> Result<(), ()>
255	{
256		let inst = inst.into();
257		// Obtain our group, creating it if necessary.
258		let group = self.group_id(&inst);
259		// Merge each reader's group with ours if the reader is the same type of
260		// instruction as us. This reduces the total number of expensive merges.
261		for reader in self
262			.analyzer
263			.readers()
264			.get(&inst.destination().unwrap())
265			.unwrap()
266			.clone()
267		{
268			let reader = &self.instructions[reader.0];
269			if filter(reader)
270			{
271				let reader_group = self.group_id(&reader.clone());
272				self.merge_group_ids(group, reader_group);
273			}
274		}
275		Ok(())
276	}
277
278	/// Answer the identifier of the group to which the specified instruction
279	/// belongs, creating a new group and branding the instruction if
280	/// necessary.
281	///
282	/// # Parameters
283	/// - `inst`: The instruction.
284	///
285	/// # Returns
286	/// The requested group identifier.
287	fn group_id(&mut self, inst: &Instruction) -> usize
288	{
289		match self.groups.get(inst)
290		{
291			Some(index) => *index,
292			None =>
293			{
294				let index = self.next_group;
295				self.groups.insert(inst.clone(), index);
296				self.next_group += 1;
297				index
298			}
299		}
300	}
301
302	/// Merge two groups of commutative instructions.
303	///
304	/// # Parameters
305	/// - `first`: The first group identifier.
306	/// - `second`: The second group identifier.
307	fn merge_group_ids(&mut self, first: usize, second: usize)
308	{
309		if first != second
310		{
311			// We prefer to keep the lower group number.
312			let min = min(first, second);
313			let max = max(first, second);
314			// For each instruction in the second group, move it to the first
315			// group.
316			self.groups.iter_mut().for_each(|(_, group)| {
317				if *group == max
318				{
319					*group = min;
320				}
321			});
322			assert!(self.groups.values().all(|&group| group != max));
323		}
324	}
325
326	/// Answer the instructions in the specified group, sorted by destination
327	/// register.
328	///
329	/// # Parameters
330	/// - `group`: The group identifier.
331	///
332	/// # Returns
333	/// The instruction group.
334	fn group(&self, group: usize) -> Vec<Instruction>
335	{
336		let mut instructions = self
337			.groups
338			.iter()
339			.filter_map(|(inst, g)| if *g == group { Some(inst) } else { None })
340			.map(|inst| self.replacements.get(inst).unwrap_or(inst).clone())
341			.collect::<Vec<_>>();
342		instructions.sort_by_key(Instruction::destination);
343		instructions
344	}
345
346	/// Rewrite the instructions in a commutative group such that the immediate
347	/// operands are read first and all registers are read in ascending order.
348	/// Hoisting the immediate operands to the front of a chain of commutative
349	/// instructions makes them more amenable to other optimizations. Do not
350	/// emit the instructions, just populate the replacement map.
351	///
352	/// # Type Parameters
353	/// - `I`: The type of instruction to rewrite.
354	///
355	/// # Parameters
356	/// - `group`: The group of instructions to rewrite.
357	/// - `constructor`: A function that constructs a new instruction of the
358	///   appropriate type from the destination and source operands of the
359	///   original instruction.
360	fn rewrite_commutative_group<I>(
361		&mut self,
362		group: usize,
363		constructor: impl Fn(AddressingMode, &[AddressingMode]) -> I
364	) where
365		I: Into<Instruction>
366	{
367		// Collect all instructions within the specified commutative group,
368		// using any replacements that have already been established. Extract
369		// their operands and sort them in descending order, such that
370		// immediates are at the front and registers are at the back. Ensure
371		// that registers are written before they are read by preserving
372		// ascending order. Vectors can only efficiently pop from the back, so
373		// we reverse the order of the operands before iterating over them.
374		let instructions = self.group(group);
375		let mut ops = instructions
376			.iter()
377			.flat_map(Instruction::sources)
378			.collect::<Vec<_>>();
379		ops.sort();
380		ops.reverse();
381		let arity = ops.len() / instructions.len();
382		instructions.iter().for_each(|inst| {
383			let ops =
384				(0..arity).map(|_| ops.pop().unwrap()).collect::<Vec<_>>();
385			let new_inst =
386				constructor(inst.destination().unwrap(), &ops).into();
387			self.replacements.insert((*inst).clone(), new_inst);
388		});
389	}
390
391	/// Rewrite the instructions in a nearly commutative group such that the
392	/// commutative immediate operands are read first and all commutative
393	/// registers are read in ascending order. Hoisting the immediate operands
394	/// to the front of a chain of nearly commutative instructions makes them
395	/// more amenable to other optimizations. Do not emit the instructions, just
396	/// populate the replacement map. The first instruction in the group does
397	/// not commute with the rest, so we handle it specially and emit it last.
398	///
399	/// # Type Parameters
400	/// - `C`: The type of commutative instruction to rewrite.
401	/// - `F`: The type of final instruction to rewrite.
402	///
403	/// # Parameters
404	/// - `group`: The group of instructions to rewrite.
405	/// - `commutative_constructor`: A function that constructs a new
406	///   commutative instruction of the appropriate type from the destination
407	///   and source operands of the original instruction.
408	/// - `final_constructor`: A function that constructs a new terminal
409	///   instruction of the appropriate type from the destination and source
410	///   operands of the original instruction.
411	fn rewrite_nearly_commutative_group<C, F>(
412		&mut self,
413		group: usize,
414		commutative_constructor: impl Fn(AddressingMode, &[AddressingMode]) -> C,
415		final_constructor: impl Fn(AddressingMode, &[AddressingMode]) -> F
416	) where
417		C: Into<Instruction>,
418		F: Into<Instruction>
419	{
420		// Collect all instructions within the nearly commutative group. If
421		// there's only a single element in the group, then there's nothing to
422		// do.
423		let mut instructions = self.group(group);
424		if instructions.len() > 1
425		{
426			// Extract all of the destination registers from the complete chain,
427			// including the terminal instruction. They are already sorted.
428			let mut targets = instructions
429				.iter()
430				.flat_map(Instruction::destination)
431				.collect::<Vec<_>>();
432			// Extract all of the operands from the complete chain.
433			let mut ops = instructions
434				.iter()
435				.flat_map(Instruction::sources)
436				.collect::<Vec<_>>();
437			// The very first operand is not commutative with the rest, so we
438			// handle it specially. We'll pop it off the front of the list and
439			// re-inject it when we emit the final non-commutative instruction.
440			let first_op = ops.remove(0);
441			// The next to last operand is the destination of the last
442			// commutative instruction. We swap it with the last operand so that
443			// we can pop it off the back of the list. We sort the remaining
444			// operands, which are all part of the commutative chain.
445			let len = ops.len();
446			ops.swap(len - 2, len - 1);
447			let commutative_result = ops.pop().unwrap();
448			ops.sort();
449			// The first instruction is non-commutative, and needs to be emitted
450			// last. We remove it from the list of instructions, rewrite it with
451			// the correct operands, and save it to emit at the end.
452			let first_inst = instructions.remove(0);
453			let terminal = final_constructor(
454				targets.pop().unwrap(),
455				&[first_op, commutative_result]
456			)
457			.into();
458			// Reverse the remaining operands and destinations prior to
459			// traversing the commutative instructions.
460			ops.reverse();
461			targets.reverse();
462			// Rewrite the commutative instructions in the chain, now that
463			// everything is in the correct order for further optimization.
464			let arity = ops.len() / instructions.len();
465			let mut previous = first_inst.clone();
466			instructions.iter().for_each(|inst| {
467				let dest = targets.pop().unwrap();
468				let ops =
469					(0..arity).map(|_| ops.pop().unwrap()).collect::<Vec<_>>();
470				let new_inst = commutative_constructor(dest, &ops).into();
471				self.replacements.insert(previous.clone(), new_inst);
472				previous = inst.clone();
473			});
474			self.replacements.insert(previous, terminal);
475		}
476	}
477}