1use super::MegakernelWorkItem;
7
8mod prologue;
9pub use prologue::shared_prologue_length;
10
11pub(super) const MAX_DENSE_FUSION_ITEMS: usize = 4096;
15
/// Reusable buffers for the fusion selectors, so repeated calls can avoid
/// reallocating.
///
/// After a selection call, `result` holds one flag per work item
/// (`1` = selected, `0` = not selected).
#[derive(Debug, Default)]
pub struct FusionSelectionScratch {
    // Candidate visiting order; reset to the identity permutation by `prepare`.
    order: Vec<usize>,
    // Per-item selection flags returned to callers.
    result: Vec<u32>,
    // Per-item count of conflicting partners, used as a sort tie-breaker.
    conflict_degrees: Vec<u32>,
    // Accepted item indices, in acceptance order.
    selected: Vec<usize>,
}

impl FusionSelectionScratch {
    /// Borrow the selection flags produced by the most recent call.
    #[must_use]
    pub fn result(&self) -> &[u32] {
        self.result.as_slice()
    }

    /// Move the selection flags out, leaving an empty vector behind.
    #[must_use]
    pub fn take_result(&mut self) -> Vec<u32> {
        std::mem::replace(&mut self.result, Vec::new())
    }

    /// Reset every buffer for a batch of `n` items while keeping allocations.
    fn prepare(&mut self, n: usize) {
        // Clear first so that every retained slot is re-zeroed, not only the
        // newly grown tail.
        fn zeroed(buffer: &mut Vec<u32>, len: usize) {
            buffer.clear();
            buffer.resize(len, 0);
        }

        self.selected.clear();
        zeroed(&mut self.conflict_degrees, n);
        zeroed(&mut self.result, n);
        self.order.clear();
        self.order.extend(0..n);
    }
}
51
/// Reasons the checked fusion selectors reject their inputs.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FusionSelectionError {
    /// `n * n` does not fit in `usize`. `n` is reported as `usize::MAX` when
    /// the `u32` item count itself cannot be represented as `usize`.
    ExchangeSizeOverflow {
        /// Item count whose square overflowed.
        n: usize,
    },
    /// The cost slice does not contain exactly one entry per item.
    CostLen {
        /// Required length (`n`).
        expected: usize,
        /// Length that was supplied.
        actual: usize,
    },
    /// The exchange graph is not a dense `n * n` matrix.
    ExchangeAdjLen {
        /// Required length (`n * n`).
        expected: usize,
        /// Length that was supplied.
        actual: usize,
    },
}

impl std::fmt::Display for FusionSelectionError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            Self::ExchangeSizeOverflow { n } => {
                write!(
                    f,
                    "megakernel fusion selector n*n overflow for n={n}. Fix: shard the work batch before fusion selection."
                )
            }
            Self::CostLen { expected, actual } => {
                write!(
                    f,
                    "megakernel fusion selector cost length {actual} does not match n={expected}. Fix: pass one cost per work item."
                )
            }
            Self::ExchangeAdjLen { expected, actual } => {
                write!(
                    f,
                    "megakernel fusion selector exchange_adj length {actual} does not match n*n={expected}. Fix: pass a dense row-major n*n exchange graph."
                )
            }
        }
    }
}

impl std::error::Error for FusionSelectionError {}

/// Check that the selector inputs describe a consistent `n`-item problem.
///
/// The checks run in this order: `n * n` must fit in `usize`, then
/// `cost_len == n`, then `exchange_adj_len == n * n`.
///
/// Returns `(n, n * n)` as `usize` on success.
fn validate_selector_shape(
    cost_len: usize,
    n: u32,
    exchange_adj_len: usize,
) -> Result<(usize, usize), FusionSelectionError> {
    let Ok(side) = usize::try_from(n) else {
        return Err(FusionSelectionError::ExchangeSizeOverflow { n: usize::MAX });
    };
    let Some(cells) = side.checked_mul(side) else {
        return Err(FusionSelectionError::ExchangeSizeOverflow { n: side });
    };

    if cost_len != side {
        return Err(FusionSelectionError::CostLen {
            expected: side,
            actual: cost_len,
        });
    }
    if exchange_adj_len != cells {
        return Err(FusionSelectionError::ExchangeAdjLen {
            expected: cells,
            actual: exchange_adj_len,
        });
    }
    Ok((side, cells))
}
121
122#[derive(Debug, Default)]
127pub struct CompactFusionPlanningScratch {
128 costs_q16: Vec<u16>,
129 stalks: Vec<f32>,
130 diffused_stalks: Vec<f32>,
131 effective_divergence: Vec<u32>,
132 deltas: Vec<f32>,
133 sorted_deltas: Vec<f32>,
134 exchange_adj: Vec<u32>,
135 order: Vec<usize>,
136 selection: FusionSelectionScratch,
137}
138
139impl CompactFusionPlanningScratch {
140 #[must_use]
142 pub fn exchange_adj(&self) -> &[u32] {
143 &self.exchange_adj
144 }
145
146 #[must_use]
148 pub fn selected(&self) -> &[u32] {
149 self.selection.result()
150 }
151}
152
/// Decide which `work_items` to fuse, using `scratch` for all intermediate
/// buffers.
///
/// Returns a slice with one flag per work item (`1` = selected, `0` = not
/// selected). The slice borrows `scratch`, and the same flags are also
/// available afterwards through `CompactFusionPlanningScratch::selected`.
///
/// - More than `MAX_DENSE_FUSION_ITEMS` items: every item is selected and the
///   exchange graph is cleared, without building a dense graph.
/// - No items: every buffer is cleared and the result is empty.
/// - Otherwise: two items conflict when they share an `op_handle` or when both
///   are flagged divergent. A greedy pass then keeps a conflict-free subset,
///   visiting items by ascending cost, then conflict degree, then index.
pub fn plan_compact_fusion_into<'a>(
    work_items: &[MegakernelWorkItem],
    scratch: &'a mut CompactFusionPlanningScratch,
) -> &'a [u32] {
    let n = work_items.len();
    // Oversized batch: skip the O(n^2) graph and keep everything.
    if n > MAX_DENSE_FUSION_ITEMS {
        scratch.selection.prepare(n);
        scratch.selection.result.fill(1);
        scratch.exchange_adj.clear();
        return scratch.selection.result();
    }

    if n == 0 {
        scratch.costs_q16.clear();
        scratch.stalks.clear();
        scratch.diffused_stalks.clear();
        scratch.effective_divergence.clear();
        scratch.deltas.clear();
        scratch.sorted_deltas.clear();
        scratch.exchange_adj.clear();
        scratch.selection.prepare(0);
        return scratch.selection.result();
    }

    // Every item starts at the maximum cost; the discounts below lower it.
    scratch.costs_q16.clear();
    scratch.costs_q16.resize(n, u16::MAX);

    // The stalk value depends only on the item's position (index * 0.001),
    // not on the item's contents.
    scratch.stalks.clear();
    scratch.stalks.extend(
        work_items
            .iter()
            .enumerate()
            .map(|(item_idx, _item)| (item_idx as f32) * 0.001),
    );
    // Eight diffusion rounds; each one scales every value by (1 - 0.5 * 0.7) = 0.65.
    scratch.diffused_stalks.clear();
    scratch.diffused_stalks.extend_from_slice(&scratch.stalks);
    for _ in 0..8 {
        for value in &mut scratch.diffused_stalks {
            *value -= 0.5_f32 * 0.7_f32 * *value;
        }
    }

    // An item is flagged divergent when diffusion moved it by more than the
    // threshold.
    let divergence_threshold = 0.05_f32;
    let mut delta_sum = 0.0_f32;
    let mut delta_max = 0.0_f32;
    scratch.effective_divergence.clear();
    for (&initial, &diffused) in scratch.stalks.iter().zip(scratch.diffused_stalks.iter()) {
        let delta = (initial - diffused).abs();
        delta_sum += delta;
        delta_max = delta_max.max(delta);
        scratch
            .effective_divergence
            .push(u32::from(delta > divergence_threshold));
    }

    // gap_signal = mean(delta) / max(delta), or 1.0 when every delta is zero.
    let n_f32 = n as f32;
    let gap_signal = if delta_max > 0.0_f32 && n_f32 > 0.0_f32 {
        delta_sum / (n_f32 * delta_max)
    } else {
        1.0_f32
    };
    // Low gap signal: clear the divergence flag of every item whose delta is
    // below the median delta.
    // NOTE(review): the stalks above are `index * 0.001` and each diffusion
    // round scales every value by the same factor, so each delta is
    // proportional to the item index. gap_signal therefore comes out near 0.5
    // for n >= 2 (and 1.0 for n == 1), which makes this branch look
    // unreachable with the current synthetic stalks. Confirm before relying on it.
    if gap_signal < 0.3 {
        scratch.deltas.clear();
        scratch.deltas.extend(
            scratch
                .stalks
                .iter()
                .zip(scratch.diffused_stalks.iter())
                .map(|(s, d)| (s - d).abs()),
        );
        scratch.sorted_deltas.clear();
        scratch.sorted_deltas.extend_from_slice(&scratch.deltas);
        scratch
            .sorted_deltas
            .sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
        let median = scratch
            .sorted_deltas
            .get(scratch.sorted_deltas.len() / 2)
            .copied()
            .unwrap_or(0.0);
        for (flag, delta) in scratch
            .effective_divergence
            .iter_mut()
            .zip(scratch.deltas.iter())
        {
            if *delta < median {
                *flag = 0;
            }
        }
    }

    // Dense row-major n*n conflict graph: entry (i, j) lives at i * n + j.
    scratch.exchange_adj.clear();
    let dense_cells = n * n;
    scratch.exchange_adj.resize(dense_cells, 0);
    let mut has_exchange_conflict = false;

    // Sort indices by op_handle so that duplicate handles become adjacent.
    let mut has_op_conflict = false;
    scratch.order.clear();
    scratch.order.extend(0..n);
    if scratch.order.len() > 1 {
        scratch
            .order
            .sort_unstable_by_key(|&item_idx| work_items[item_idx].op_handle);
        if scratch
            .order
            .windows(2)
            .any(|window| work_items[window[0]].op_handle == work_items[window[1]].op_handle)
        {
            has_op_conflict = true;
        }
    }
    // True when some item's output_handle equals the next item's input_handle.
    let has_output_input_chain = (0..n.checked_sub(1).unwrap_or(0)).any(|i| {
        work_items.get(i).map(|w| w.output_handle) == work_items.get(i + 1).map(|w| w.input_handle)
    });
    let has_divergence_conflict = scratch.effective_divergence.iter().any(|&v| v != 0);
    scratch.selection.prepare(n);

    // No conflict source at all: every item is selected. costs_q16 is not read
    // again on this path, so the chain discount only updates scratch state.
    if !has_op_conflict && !has_divergence_conflict {
        if has_output_input_chain {
            for cost in scratch.costs_q16.iter_mut() {
                *cost = discount_q16(*cost, 3_276);
            }
        }

        scratch.selection.result.fill(1);
        return scratch.selection.result();
    }

    {
        let conflict_degrees = &mut scratch.selection.conflict_degrees;
        for i in 0..n {
            let row_start = i * n;
            for j in 0..n {
                if i == j {
                    continue;
                }
                let same_op = work_items[i].op_handle == work_items[j].op_handle;
                // Small batches: each same-op partner lowers item i's cost by
                // 3_276 (about 5% of the u16 range; presumably a fusion bonus,
                // TODO confirm).
                if n <= 32 && same_op {
                    scratch.costs_q16[i] = discount_q16(scratch.costs_q16[i], 3_276);
                }
                let divergent =
                    scratch.effective_divergence[i] != 0 && scratch.effective_divergence[j] != 0;
                if same_op || divergent {
                    scratch.exchange_adj[row_start + j] = 1;
                    // Both (i, j) and (j, i) are visited; count each unordered
                    // pair only once.
                    if i < j {
                        conflict_degrees[i] = increment_degree(conflict_degrees[i]);
                        conflict_degrees[j] = increment_degree(conflict_degrees[j]);
                    }
                    has_exchange_conflict = true;
                }
            }
        }
    }
    if has_output_input_chain {
        for cost in scratch.costs_q16.iter_mut() {
            *cost = discount_q16(*cost, 3_276);
        }
    }
    // For example, a single divergent item has no divergent partner, so no edge
    // is added.
    if !has_exchange_conflict {
        scratch.selection.result.fill(1);
        return scratch.selection.result();
    }

    // Greedy visiting order: cheapest first, then fewest conflicts, then index.
    let conflict_degrees = &scratch.selection.conflict_degrees;
    scratch.selection.order.sort_unstable_by(|&a, &b| {
        scratch.costs_q16[a]
            .cmp(&scratch.costs_q16[b])
            .then_with(|| conflict_degrees[a].cmp(&conflict_degrees[b]))
            .then_with(|| a.cmp(&b))
    });
    select_ordered_maximal(
        &scratch.exchange_adj,
        n,
        &scratch.selection.order,
        &mut scratch.selection.selected,
        &mut scratch.selection.result,
    );
    scratch.selection.result()
}
336
337#[must_use]
345pub fn select_fused_subset(costs: &[f64], n: u32, exchange_adj: &[u32]) -> Vec<u32> {
346 let mut scratch = FusionSelectionScratch::default();
347 select_fused_subset_into(costs, n, exchange_adj, &mut scratch);
348 scratch.take_result()
349}
350
351pub fn select_fused_subset_into(
353 costs: &[f64],
354 n: u32,
355 exchange_adj: &[u32],
356 scratch: &mut FusionSelectionScratch,
357) {
358 if let Ok((n_usize, _cells)) = validate_selector_shape(costs.len(), n, exchange_adj.len()) {
359 if n_usize <= MAX_DENSE_FUSION_ITEMS && exchange_adj.iter().all(|&edge| edge == 0) {
360 scratch.prepare(n_usize);
361 scratch.result.fill(1);
362 return;
363 }
364 }
365 if select_fused_subset_checked_into(costs, n, exchange_adj, scratch).is_err() {
366 scratch.prepare(0);
367 }
368}
369
370pub fn select_fused_subset_checked_into(
372 costs: &[f64],
373 n: u32,
374 exchange_adj: &[u32],
375 scratch: &mut FusionSelectionScratch,
376) -> Result<(), FusionSelectionError> {
377 let (n_usize, _cells) = validate_selector_shape(costs.len(), n, exchange_adj.len())?;
378 if n_usize > MAX_DENSE_FUSION_ITEMS {
379 scratch.prepare(n_usize);
380 scratch.result.fill(1);
381 return Ok(());
382 }
383 scratch.prepare(n_usize);
384 if exchange_adj.iter().all(|&edge| edge == 0) {
385 scratch.result.fill(1);
386 return Ok(());
387 }
388 if !compute_conflict_degrees_with_conflict(exchange_adj, n_usize, &mut scratch.conflict_degrees)
389 {
390 scratch.result.fill(1);
391 return Ok(());
392 }
393 scratch.order.sort_unstable_by(|&a, &b| {
394 costs[a]
395 .total_cmp(&costs[b])
396 .then_with(|| scratch.conflict_degrees[a].cmp(&scratch.conflict_degrees[b]))
397 .then_with(|| a.cmp(&b))
398 });
399 select_ordered_maximal(
400 exchange_adj,
401 n_usize,
402 &scratch.order,
403 &mut scratch.selected,
404 &mut scratch.result,
405 );
406 Ok(())
407}
408
409#[must_use]
415pub fn select_fused_subset_compact(costs_q16: &[u16], n: u32, exchange_adj: &[u32]) -> Vec<u32> {
416 let mut scratch = FusionSelectionScratch::default();
417 select_fused_subset_compact_into(costs_q16, n, exchange_adj, &mut scratch);
418 scratch.take_result()
419}
420
421pub fn select_fused_subset_compact_into(
423 costs_q16: &[u16],
424 n: u32,
425 exchange_adj: &[u32],
426 scratch: &mut FusionSelectionScratch,
427) {
428 if let Ok((n_usize, _cells)) = validate_selector_shape(costs_q16.len(), n, exchange_adj.len()) {
429 if n_usize <= MAX_DENSE_FUSION_ITEMS && exchange_adj.iter().all(|&edge| edge == 0) {
430 scratch.prepare(n_usize);
431 scratch.result.fill(1);
432 return;
433 }
434 }
435 if select_fused_subset_compact_checked_into(costs_q16, n, exchange_adj, scratch).is_err() {
436 scratch.prepare(0);
437 }
438}
439
440pub fn select_fused_subset_compact_checked_into(
442 costs_q16: &[u16],
443 n: u32,
444 exchange_adj: &[u32],
445 scratch: &mut FusionSelectionScratch,
446) -> Result<(), FusionSelectionError> {
447 let (n_usize, _cells) = validate_selector_shape(costs_q16.len(), n, exchange_adj.len())?;
448 if n_usize > MAX_DENSE_FUSION_ITEMS {
449 scratch.prepare(n_usize);
450 scratch.result.fill(1);
451 return Ok(());
452 }
453 scratch.prepare(n_usize);
454 if exchange_adj.iter().all(|&edge| edge == 0) {
455 scratch.result.fill(1);
456 return Ok(());
457 }
458 if !compute_conflict_degrees_with_conflict(exchange_adj, n_usize, &mut scratch.conflict_degrees)
459 {
460 scratch.result.fill(1);
461 return Ok(());
462 }
463 scratch.order.sort_unstable_by(|&a, &b| {
464 costs_q16[a]
465 .cmp(&costs_q16[b])
466 .then_with(|| scratch.conflict_degrees[a].cmp(&scratch.conflict_degrees[b]))
467 .then_with(|| a.cmp(&b))
468 });
469 select_ordered_maximal(
470 exchange_adj,
471 n_usize,
472 &scratch.order,
473 &mut scratch.selected,
474 &mut scratch.result,
475 );
476 Ok(())
477}
478
479#[must_use]
482
483pub fn select_optimal_fused_subset(costs: &[f64], n: u32, exchange_adj: &[u32]) -> Vec<u32> {
484 select_fused_subset(costs, n, exchange_adj)
485}
486
487#[must_use]
489pub fn select_fused_subset_with_rate(costs: &[f64], n: u32, exchange_adj: &[u32]) -> Vec<u32> {
490 select_fused_subset(costs, n, exchange_adj)
491}
492
493#[must_use]
500pub fn select_fused_subset_pruned(
501 costs: &[f64],
502 n: u32,
503 exchange_adj: &[u32],
504 dead_mask: &[bool],
505) -> Vec<u32> {
506 let mut selection = select_fused_subset(costs, n, exchange_adj);
507 prune_dead_arms_inplace(&mut selection, dead_mask);
508 selection
509}
510
511pub fn select_fused_subset_pruned_into(
513 costs: &[f64],
514 n: u32,
515 exchange_adj: &[u32],
516 dead_mask: &[bool],
517 scratch: &mut FusionSelectionScratch,
518) {
519 select_fused_subset_into(costs, n, exchange_adj, scratch);
520 prune_dead_arms_inplace(&mut scratch.result, dead_mask);
521}
522
/// Clear every selected slot whose `dead_mask` entry is `true`.
///
/// Returns how many slots changed from non-zero to `0` (saturating at
/// `u32::MAX`). When the two slices differ in length, nothing is modified and
/// `0` is returned.
pub fn prune_dead_arms_inplace(selection: &mut [u32], dead_mask: &[bool]) -> u32 {
    if selection.len() == dead_mask.len() {
        selection
            .iter_mut()
            .zip(dead_mask.iter().copied())
            .filter_map(|(slot, dead)| (dead && *slot != 0).then_some(slot))
            .fold(0_u32, |eliminated, slot| {
                *slot = 0;
                eliminated.saturating_add(1)
            })
    } else {
        0
    }
}
557
/// Count, for every item, how many other items it conflicts with.
///
/// `exchange_adj` is a dense row-major `n * n` matrix. A pair `i < j` conflicts
/// when `(i, j)` or `(j, i)` is non-zero; the diagonal is ignored. `out` is
/// zeroed before counting. Returns whether at least one conflicting pair was
/// found.
fn compute_conflict_degrees_with_conflict(exchange_adj: &[u32], n: usize, out: &mut [u32]) -> bool {
    debug_assert_eq!(out.len(), n);
    out.fill(0);

    let mut any_edge = false;
    let upper_pairs = (0..n).flat_map(|i| ((i + 1)..n).map(move |j| (i, j)));
    for (i, j) in upper_pairs {
        let linked = exchange_adj[i * n + j] != 0 || exchange_adj[j * n + i] != 0;
        if linked {
            out[i] = increment_degree(out[i]);
            out[j] = increment_degree(out[j]);
            any_edge = true;
        }
    }
    any_edge
}

/// Lower a u16 cost by `amount`, clamping at zero.
fn discount_q16(value: u16, amount: u16) -> u16 {
    value.checked_sub(amount).unwrap_or(0)
}

/// Add one to a conflict degree, clamping at `u32::MAX`.
fn increment_degree(value: u32) -> u32 {
    value.checked_add(1).unwrap_or(u32::MAX)
}
582
/// Greedily build a maximal conflict-free subset, visiting candidates in
/// `order`.
///
/// `exchange_adj` is a dense row-major `n * n` matrix. Items `i` and `j`
/// conflict when `(i, j)` or `(j, i)` is non-zero; the diagonal is ignored.
/// An item from `order` is accepted when it conflicts with no previously
/// accepted item. Accepted items get `result[item] = 1` and are appended to
/// `selected` in acceptance order. Entries of `order` that are `>= n` are
/// skipped.
///
/// `result` is zeroed and `selected` is cleared before the pass. Indexing
/// panics if `exchange_adj` has fewer than `n * n` entries or if `result` is
/// too short for an accepted item.
fn select_ordered_maximal(
    exchange_adj: &[u32],
    n: usize,
    order: &[usize],
    selected: &mut Vec<usize>,
    result: &mut [u32],
) {
    result.fill(0);
    selected.clear();

    // Up to 256 items (at most 4 words per row), the per-item conflict bitsets
    // live in stack arrays sized for that width. Larger graphs use the heap.
    match n.div_ceil(64) {
        0 => {}
        1 => ordered_maximal_stack::<1, 64>(exchange_adj, n, order, selected, result),
        2 => ordered_maximal_stack::<2, 128>(exchange_adj, n, order, selected, result),
        3 => ordered_maximal_stack::<3, 192>(exchange_adj, n, order, selected, result),
        4 => ordered_maximal_stack::<4, 256>(exchange_adj, n, order, selected, result),
        words => ordered_maximal_heap(exchange_adj, n, words, order, selected, result),
    }
}

/// Stack-backed greedy pass for graphs of at most `ROWS` (= `WORDS * 64`) items.
fn ordered_maximal_stack<const WORDS: usize, const ROWS: usize>(
    exchange_adj: &[u32],
    n: usize,
    order: &[usize],
    selected: &mut Vec<usize>,
    result: &mut [u32],
) {
    debug_assert!(n <= ROWS && ROWS == WORDS * 64);
    let mut conflict_masks = [[0_u64; WORDS]; ROWS];
    for_each_conflict_pair(exchange_adj, n, |i, j| {
        set_mask_bit(&mut conflict_masks[i], j);
        set_mask_bit(&mut conflict_masks[j], i);
    });

    let mut selected_mask = [0_u64; WORDS];
    for &item in order.iter().filter(|&&item| item < n) {
        if !masks_intersect(&conflict_masks[item], &selected_mask) {
            result[item] = 1;
            set_mask_bit(&mut selected_mask, item);
            selected.push(item);
        }
    }
}

/// Heap-backed greedy pass for graphs of any size. Each item's conflict row is
/// `words` (= `ceil(n / 64)`) consecutive entries of one flat vector.
fn ordered_maximal_heap(
    exchange_adj: &[u32],
    n: usize,
    words: usize,
    order: &[usize],
    selected: &mut Vec<usize>,
    result: &mut [u32],
) {
    let mut conflict_masks = vec![0_u64; n * words];
    for_each_conflict_pair(exchange_adj, n, |i, j| {
        set_mask_bit(&mut conflict_masks[i * words..(i + 1) * words], j);
        set_mask_bit(&mut conflict_masks[j * words..(j + 1) * words], i);
    });

    let mut selected_mask = vec![0_u64; words];
    for &item in order.iter().filter(|&&item| item < n) {
        let row = &conflict_masks[item * words..(item + 1) * words];
        if !masks_intersect(row, &selected_mask) {
            result[item] = 1;
            set_mask_bit(&mut selected_mask, item);
            selected.push(item);
        }
    }
}

/// Call `on_conflict(i, j)` once for every pair `i < j` that has an edge in
/// either direction.
fn for_each_conflict_pair(
    exchange_adj: &[u32],
    n: usize,
    mut on_conflict: impl FnMut(usize, usize),
) {
    for i in 0..n {
        for j in (i + 1)..n {
            if exchange_adj[i * n + j] != 0 || exchange_adj[j * n + i] != 0 {
                on_conflict(i, j);
            }
        }
    }
}

/// Set bit `index` in a little-endian array of 64-bit words.
fn set_mask_bit(words: &mut [u64], index: usize) {
    words[index / 64] |= 1_u64 << (index % 64);
}

/// Return whether the two bitsets share any set bit.
fn masks_intersect(a: &[u64], b: &[u64]) -> bool {
    a.iter().zip(b).any(|(x, y)| x & y != 0)
}
818
819#[cfg(test)]
820mod tests;