Skip to main content

vyre_runtime/megakernel/planner/
fusion.rs

1//! Fusion-subset selection used by megakernel batch dispatchers.
2//!
3//! This runtime path is deliberately self-contained: it does not call
4//! self-substrate CPU reference solvers while preparing megakernel work.
5
6use super::MegakernelWorkItem;
7
8mod prologue;
9pub use prologue::shared_prologue_length;
10
11/// Hard cap for dense exchange-graph planning.
12///
13/// This avoids dense O(n*n) matrix growth in pathological batches.
14pub(super) const MAX_DENSE_FUSION_ITEMS: usize = 4096;
15
/// Caller-owned buffers for megakernel fusion-subset selection.
///
/// A scheduler can keep one instance per worker and hand it to every
/// selector call, so the visit order, conflict degrees, selected-item list
/// and 0/1 result vector are reused instead of reallocated per batch.
#[derive(Debug, Default)]
pub struct FusionSelectionScratch {
    /// Item indices in the order the greedy selector visits them.
    order: Vec<usize>,
    /// 0/1 keep flag per item, published through [`Self::result`].
    result: Vec<u32>,
    /// Number of conflicting partners per item.
    conflict_degrees: Vec<u32>,
    /// Items kept by the last selector run, in visit order.
    selected: Vec<usize>,
}

impl FusionSelectionScratch {
    /// Borrow the 0/1 fusion vector written by the most recent selector call.
    #[must_use]
    pub fn result(&self) -> &[u32] {
        self.result.as_slice()
    }

    /// Take ownership of the current result, leaving an empty vector behind
    /// while every other buffer keeps its allocation.
    #[must_use]
    pub fn take_result(&mut self) -> Vec<u32> {
        std::mem::take(&mut self.result)
    }

    /// Reset all buffers for a batch of `n` items: `order` becomes `0..n`,
    /// `result` and `conflict_degrees` become `n` zeros, `selected` is emptied.
    fn prepare(&mut self, n: usize) {
        let Self {
            order,
            result,
            conflict_degrees,
            selected,
        } = self;
        order.clear();
        order.extend(0..n);
        for buffer in [result, conflict_degrees] {
            buffer.clear();
            buffer.resize(n, 0);
        }
        selected.clear();
    }
}
51
/// Malformed-input error reported by the checked fusion subset selectors.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FusionSelectionError {
    /// The dense graph size `n * n` does not fit in `usize`.
    ExchangeSizeOverflow {
        /// Item count that was requested.
        n: usize,
    },
    /// The cost slice does not hold exactly one entry per item.
    CostLen {
        /// Number of costs required (`n`).
        expected: usize,
        /// Number of costs supplied.
        actual: usize,
    },
    /// The exchange adjacency slice is not a dense `n * n` matrix.
    ExchangeAdjLen {
        /// Number of row-major cells required (`n * n`).
        expected: usize,
        /// Number of cells supplied.
        actual: usize,
    },
}

impl std::fmt::Display for FusionSelectionError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            Self::ExchangeSizeOverflow { n } => {
                write!(f, "megakernel fusion selector n*n overflow for n={n}. Fix: shard the work batch before fusion selection.")
            }
            Self::CostLen { expected, actual } => {
                write!(f, "megakernel fusion selector cost length {actual} does not match n={expected}. Fix: pass one cost per work item.")
            }
            Self::ExchangeAdjLen { expected, actual } => {
                write!(f, "megakernel fusion selector exchange_adj length {actual} does not match n*n={expected}. Fix: pass a dense row-major n*n exchange graph.")
            }
        }
    }
}

impl std::error::Error for FusionSelectionError {}
96
97fn validate_selector_shape(
98    cost_len: usize,
99    n: u32,
100    exchange_adj_len: usize,
101) -> Result<(usize, usize), FusionSelectionError> {
102    let n_usize = usize::try_from(n)
103        .map_err(|_| FusionSelectionError::ExchangeSizeOverflow { n: usize::MAX })?;
104    let cells = n_usize
105        .checked_mul(n_usize)
106        .ok_or(FusionSelectionError::ExchangeSizeOverflow { n: n_usize })?;
107    if cost_len != n_usize {
108        return Err(FusionSelectionError::CostLen {
109            expected: n_usize,
110            actual: cost_len,
111        });
112    }
113    if exchange_adj_len != cells {
114        return Err(FusionSelectionError::ExchangeAdjLen {
115            expected: cells,
116            actual: exchange_adj_len,
117        });
118    }
119    Ok((n_usize, cells))
120}
121
122/// Reusable scratch for compact runtime fusion planning.
123///
124/// Concrete drivers own command submission. Runtime owns the queue-shaping
125/// policy: cost seeds, divergence flags, exchange graph, and selector output.
126#[derive(Debug, Default)]
127pub struct CompactFusionPlanningScratch {
128    costs_q16: Vec<u16>,
129    stalks: Vec<f32>,
130    diffused_stalks: Vec<f32>,
131    effective_divergence: Vec<u32>,
132    deltas: Vec<f32>,
133    sorted_deltas: Vec<f32>,
134    exchange_adj: Vec<u32>,
135    order: Vec<usize>,
136    selection: FusionSelectionScratch,
137}
138
139impl CompactFusionPlanningScratch {
140    /// Last exchange adjacency matrix, row-major `n*n`.
141    #[must_use]
142    pub fn exchange_adj(&self) -> &[u32] {
143        &self.exchange_adj
144    }
145
146    /// Last 0/1 selection vector.
147    #[must_use]
148    pub fn selected(&self) -> &[u32] {
149        self.selection.result()
150    }
151}
152
153/// Build the compact megakernel fusion plan for one work batch.
154///
155/// Returns the selector's 0/1 keep vector. The matching exchange adjacency is
156/// retained in `scratch.exchange_adj()` for provenance and diagnostics.
157pub fn plan_compact_fusion_into<'a>(
158    work_items: &[MegakernelWorkItem],
159    scratch: &'a mut CompactFusionPlanningScratch,
160) -> &'a [u32] {
161    let n = work_items.len();
162    if n > MAX_DENSE_FUSION_ITEMS {
163        scratch.selection.prepare(n);
164        scratch.selection.result.fill(1);
165        scratch.exchange_adj.clear();
166        return scratch.selection.result();
167    }
168
169    if n == 0 {
170        scratch.costs_q16.clear();
171        scratch.stalks.clear();
172        scratch.diffused_stalks.clear();
173        scratch.effective_divergence.clear();
174        scratch.deltas.clear();
175        scratch.sorted_deltas.clear();
176        scratch.exchange_adj.clear();
177        scratch.selection.prepare(0);
178        return scratch.selection.result();
179    }
180
181    scratch.costs_q16.clear();
182    scratch.costs_q16.resize(n, u16::MAX);
183
184    scratch.stalks.clear();
185    scratch.stalks.extend(
186        work_items
187            .iter()
188            .enumerate()
189            .map(|(item_idx, _item)| (item_idx as f32) * 0.001),
190    );
191    scratch.diffused_stalks.clear();
192    scratch.diffused_stalks.extend_from_slice(&scratch.stalks);
193    for _ in 0..8 {
194        for value in &mut scratch.diffused_stalks {
195            *value -= 0.5_f32 * 0.7_f32 * *value;
196        }
197    }
198
199    let divergence_threshold = 0.05_f32;
200    let mut delta_sum = 0.0_f32;
201    let mut delta_max = 0.0_f32;
202    scratch.effective_divergence.clear();
203    for (&initial, &diffused) in scratch.stalks.iter().zip(scratch.diffused_stalks.iter()) {
204        let delta = (initial - diffused).abs();
205        delta_sum += delta;
206        delta_max = delta_max.max(delta);
207        scratch
208            .effective_divergence
209            .push(u32::from(delta > divergence_threshold));
210    }
211
212    let n_f32 = n as f32;
213    let gap_signal = if delta_max > 0.0_f32 && n_f32 > 0.0_f32 {
214        delta_sum / (n_f32 * delta_max)
215    } else {
216        1.0_f32
217    };
218    if gap_signal < 0.3 {
219        scratch.deltas.clear();
220        scratch.deltas.extend(
221            scratch
222                .stalks
223                .iter()
224                .zip(scratch.diffused_stalks.iter())
225                .map(|(s, d)| (s - d).abs()),
226        );
227        scratch.sorted_deltas.clear();
228        scratch.sorted_deltas.extend_from_slice(&scratch.deltas);
229        scratch
230            .sorted_deltas
231            .sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
232        let median = scratch
233            .sorted_deltas
234            .get(scratch.sorted_deltas.len() / 2)
235            .copied()
236            .unwrap_or(0.0);
237        for (flag, delta) in scratch
238            .effective_divergence
239            .iter_mut()
240            .zip(scratch.deltas.iter())
241        {
242            if *delta < median {
243                *flag = 0;
244            }
245        }
246    }
247
248    scratch.exchange_adj.clear();
249    let dense_cells = n.checked_mul(n).unwrap_or_else(|| {
250        panic!(
251            "megakernel compact fusion exchange graph overflowed usize. Fix: shard the work batch before fusion planning."
252        )
253    });
254    scratch.exchange_adj.resize(dense_cells, 0);
255    let mut has_exchange_conflict = false;
256
257    let mut has_op_conflict = false;
258    scratch.order.clear();
259    scratch.order.extend(0..n);
260    if scratch.order.len() > 1 {
261        scratch
262            .order
263            .sort_unstable_by_key(|&item_idx| work_items[item_idx].op_handle);
264        if scratch
265            .order
266            .windows(2)
267            .any(|window| work_items[window[0]].op_handle == work_items[window[1]].op_handle)
268        {
269            has_op_conflict = true;
270        }
271    }
272    let has_output_input_chain = (0..n.checked_sub(1).unwrap_or(0)).any(|i| {
273        work_items.get(i).map(|w| w.output_handle) == work_items.get(i + 1).map(|w| w.input_handle)
274    });
275    let has_divergence_conflict = scratch.effective_divergence.iter().any(|&v| v != 0);
276    scratch.selection.prepare(n);
277
278    if !has_op_conflict && !has_divergence_conflict {
279        if has_output_input_chain {
280            for cost in scratch.costs_q16.iter_mut() {
281                *cost = discount_q16(*cost, 3_276);
282            }
283        }
284
285        scratch.selection.result.fill(1);
286        return scratch.selection.result();
287    }
288
289    {
290        let conflict_degrees = &mut scratch.selection.conflict_degrees;
291        for i in 0..n {
292            let row_start = i * n;
293            for j in 0..n {
294                if i == j {
295                    continue;
296                }
297                let same_op = work_items[i].op_handle == work_items[j].op_handle;
298                if n <= 32 && same_op {
299                    scratch.costs_q16[i] = discount_q16(scratch.costs_q16[i], 3_276);
300                }
301                let divergent =
302                    scratch.effective_divergence[i] != 0 && scratch.effective_divergence[j] != 0;
303                if same_op || divergent {
304                    scratch.exchange_adj[row_start + j] = 1;
305                    if i < j {
306                        conflict_degrees[i] = increment_degree(conflict_degrees[i]);
307                        conflict_degrees[j] = increment_degree(conflict_degrees[j]);
308                    }
309                    has_exchange_conflict = true;
310                }
311            }
312        }
313    }
314    if has_output_input_chain {
315        for cost in scratch.costs_q16.iter_mut() {
316            *cost = discount_q16(*cost, 3_276);
317        }
318    }
319    if !has_exchange_conflict {
320        scratch.selection.result.fill(1);
321        return scratch.selection.result();
322    }
323
324    let conflict_degrees = &scratch.selection.conflict_degrees;
325    scratch.selection.order.sort_unstable_by(|&a, &b| {
326        scratch.costs_q16[a]
327            .cmp(&scratch.costs_q16[b])
328            .then_with(|| conflict_degrees[a].cmp(&conflict_degrees[b]))
329            .then_with(|| a.cmp(&b))
330    });
331    select_ordered_maximal(
332        &scratch.exchange_adj,
333        n,
334        &scratch.selection.order,
335        &mut scratch.selection.selected,
336        &mut scratch.selection.result,
337    );
338    scratch.selection.result()
339}
340
341/// Compute a deterministic maximal fusion subset for a batch of megakernel work items.
342///
343/// `costs[i]` is the dispatch cost of program `i` (lower is cheaper).
344/// `exchange_adj[i*n+j]` is non-zero when fusing `i` and `j` is
345/// incompatible (memory overflow, sync class boundary, etc.).
346///
347/// Returns a 0/1 selection vector of length `n`.
348#[must_use]
349pub fn select_fused_subset(costs: &[f64], n: u32, exchange_adj: &[u32]) -> Vec<u32> {
350    let mut scratch = FusionSelectionScratch::default();
351    select_fused_subset_into(costs, n, exchange_adj, &mut scratch);
352    scratch.take_result()
353}
354
355/// Compute the optimal fusion subset into reusable scratch buffers.
356pub fn select_fused_subset_into(
357    costs: &[f64],
358    n: u32,
359    exchange_adj: &[u32],
360    scratch: &mut FusionSelectionScratch,
361) {
362    if let Ok((n_usize, _cells)) = validate_selector_shape(costs.len(), n, exchange_adj.len()) {
363        if n_usize <= MAX_DENSE_FUSION_ITEMS && exchange_adj.iter().all(|&edge| edge == 0) {
364            scratch.prepare(n_usize);
365            scratch.result.fill(1);
366            return;
367        }
368    }
369    if select_fused_subset_checked_into(costs, n, exchange_adj, scratch).is_err() {
370        scratch.prepare(0);
371    }
372}
373
374/// Checked selector variant that reports malformed planner input.
375pub fn select_fused_subset_checked_into(
376    costs: &[f64],
377    n: u32,
378    exchange_adj: &[u32],
379    scratch: &mut FusionSelectionScratch,
380) -> Result<(), FusionSelectionError> {
381    let (n_usize, _cells) = validate_selector_shape(costs.len(), n, exchange_adj.len())?;
382    if n_usize > MAX_DENSE_FUSION_ITEMS {
383        scratch.prepare(n_usize);
384        scratch.result.fill(1);
385        return Ok(());
386    }
387    scratch.prepare(n_usize);
388    if exchange_adj.iter().all(|&edge| edge == 0) {
389        scratch.result.fill(1);
390        return Ok(());
391    }
392    if !compute_conflict_degrees_with_conflict(exchange_adj, n_usize, &mut scratch.conflict_degrees)
393    {
394        scratch.result.fill(1);
395        return Ok(());
396    }
397    scratch.order.sort_unstable_by(|&a, &b| {
398        costs[a]
399            .total_cmp(&costs[b])
400            .then_with(|| scratch.conflict_degrees[a].cmp(&scratch.conflict_degrees[b]))
401            .then_with(|| a.cmp(&b))
402    });
403    select_ordered_maximal(
404        exchange_adj,
405        n_usize,
406        &scratch.order,
407        &mut scratch.selected,
408        &mut scratch.result,
409    );
410    Ok(())
411}
412
413/// Compact-cost selector for hot runtime dispatchers.
414///
415/// `costs_q16[i]` is a normalized fixed-point dispatch cost where lower is
416/// cheaper. This avoids carrying `Vec<f64>` scratch through runtime hot paths;
417/// the exact matroid rounder still receives the same exchange graph.
418#[must_use]
419pub fn select_fused_subset_compact(costs_q16: &[u16], n: u32, exchange_adj: &[u32]) -> Vec<u32> {
420    let mut scratch = FusionSelectionScratch::default();
421    select_fused_subset_compact_into(costs_q16, n, exchange_adj, &mut scratch);
422    scratch.take_result()
423}
424
425/// Compact-cost selector using caller-owned scratch buffers.
426pub fn select_fused_subset_compact_into(
427    costs_q16: &[u16],
428    n: u32,
429    exchange_adj: &[u32],
430    scratch: &mut FusionSelectionScratch,
431) {
432    if let Ok((n_usize, _cells)) = validate_selector_shape(costs_q16.len(), n, exchange_adj.len()) {
433        if n_usize <= MAX_DENSE_FUSION_ITEMS && exchange_adj.iter().all(|&edge| edge == 0) {
434            scratch.prepare(n_usize);
435            scratch.result.fill(1);
436            return;
437        }
438    }
439    if select_fused_subset_compact_checked_into(costs_q16, n, exchange_adj, scratch).is_err() {
440        scratch.prepare(0);
441    }
442}
443
444/// Checked compact selector variant that reports malformed planner input.
445pub fn select_fused_subset_compact_checked_into(
446    costs_q16: &[u16],
447    n: u32,
448    exchange_adj: &[u32],
449    scratch: &mut FusionSelectionScratch,
450) -> Result<(), FusionSelectionError> {
451    let (n_usize, _cells) = validate_selector_shape(costs_q16.len(), n, exchange_adj.len())?;
452    if n_usize > MAX_DENSE_FUSION_ITEMS {
453        scratch.prepare(n_usize);
454        scratch.result.fill(1);
455        return Ok(());
456    }
457    scratch.prepare(n_usize);
458    if exchange_adj.iter().all(|&edge| edge == 0) {
459        scratch.result.fill(1);
460        return Ok(());
461    }
462    if !compute_conflict_degrees_with_conflict(exchange_adj, n_usize, &mut scratch.conflict_degrees)
463    {
464        scratch.result.fill(1);
465        return Ok(());
466    }
467    scratch.order.sort_unstable_by(|&a, &b| {
468        costs_q16[a]
469            .cmp(&costs_q16[b])
470            .then_with(|| scratch.conflict_degrees[a].cmp(&scratch.conflict_degrees[b]))
471            .then_with(|| a.cmp(&b))
472    });
473    select_ordered_maximal(
474        exchange_adj,
475        n_usize,
476        &scratch.order,
477        &mut scratch.selected,
478        &mut scratch.result,
479    );
480    Ok(())
481}
482
483/// Compute a cost-ordered maximal fusion subset with the same output contract
484/// as [`select_fused_subset`].
485#[must_use]
486
487pub fn select_optimal_fused_subset(costs: &[f64], n: u32, exchange_adj: &[u32]) -> Vec<u32> {
488    select_fused_subset(costs, n, exchange_adj)
489}
490
491/// Runtime-compatible selector entry point that preserves the historical API.
492#[must_use]
493pub fn select_fused_subset_with_rate(costs: &[f64], n: u32, exchange_adj: &[u32]) -> Vec<u32> {
494    select_fused_subset(costs, n, exchange_adj)
495}
496
497/// Select a cost-ordered fused subset, then eliminate arms whose gate
498/// predicates have already proven them to be no-ops for this dispatch.
499///
500/// This is the runtime-facing C5 entry point: it keeps the historical
501/// selection algorithm unchanged, then applies [`prune_dead_arms_inplace`]
502/// before the caller materializes the launch sequence.
503#[must_use]
504pub fn select_fused_subset_pruned(
505    costs: &[f64],
506    n: u32,
507    exchange_adj: &[u32],
508    dead_mask: &[bool],
509) -> Vec<u32> {
510    let mut selection = select_fused_subset(costs, n, exchange_adj);
511    prune_dead_arms_inplace(&mut selection, dead_mask);
512    selection
513}
514
515/// Reusable-scratch variant of [`select_fused_subset_pruned`].
516pub fn select_fused_subset_pruned_into(
517    costs: &[f64],
518    n: u32,
519    exchange_adj: &[u32],
520    dead_mask: &[bool],
521    scratch: &mut FusionSelectionScratch,
522) {
523    select_fused_subset_into(costs, n, exchange_adj, scratch);
524    prune_dead_arms_inplace(&mut scratch.result, dead_mask);
525}
526
/// ROADMAP C5 substrate: gated no-op middle-arm elimination.
///
/// `selection` is a 0/1 vector with one entry per arm in the megakernel
/// dispatch sequence. `dead_mask` has the same length, and `dead_mask[i] ==
/// true` means arm `i` has been proven to be a no-op for this dispatch (gate
/// predicate folds to false, output equals input, etc.). Every selected arm
/// that is dead is zeroed in place. The return value is the number of arms
/// eliminated, so the caller can log or telemeter the win.
///
/// A length mismatch is a caller contract violation. The selection is then
/// left untouched and zero is returned, so reusable planner scratch is never
/// abandoned through a panic while a checked caller records the malformed
/// planner input.
///
/// Example: in an inference megakernel, arm 1 is a `mask × value` step gated
/// on `mask != 0`. If the static analyzer proves the mask buffer is all-zero
/// for this batch, dispatch can elide arm 1 entirely. Without this elision the
/// GPU launches a full kernel that reads both buffers, computes the
/// multiplication, and writes a zero result back: pure waste.
///
/// # Panics
///
/// Panics if more than `u32::MAX` arms are eliminated in one call.
pub fn prune_dead_arms_inplace(selection: &mut [u32], dead_mask: &[bool]) -> u32 {
    if selection.len() != dead_mask.len() {
        return 0;
    }
    let mut eliminated: u32 = 0;
    for (slot, is_dead) in selection.iter_mut().zip(dead_mask.iter().copied()) {
        if !is_dead || *slot == 0 {
            continue;
        }
        *slot = 0;
        eliminated = match eliminated.checked_add(1) {
            Some(next) => next,
            None => panic!(
                "megakernel dead-arm elimination count overflowed u32. Fix: shard the fusion selection before pruning."
            ),
        };
    }
    eliminated
}
565
566fn compute_conflict_degrees_with_conflict(exchange_adj: &[u32], n: usize, out: &mut [u32]) -> bool {
567    debug_assert_eq!(out.len(), n);
568    out.fill(0);
569    let mut has_conflict = false;
570    for i in 0..n {
571        let row = i * n;
572        for j in (i + 1)..n {
573            if exchange_adj[row + j] != 0 || exchange_adj[j * n + i] != 0 {
574                out[i] = increment_degree(out[i]);
575                out[j] = increment_degree(out[j]);
576                has_conflict = true;
577            }
578        }
579    }
580    has_conflict
581}
582
/// Lower a Q16 fusion cost by `amount`, clamping at zero.
///
/// The compact planner seeds every cost at `u16::MAX` and, in batches of up
/// to 32 items, applies one 3_276 discount per same-op partner, plus one more
/// for an output→input chain. `u16::MAX` absorbs only 20 such discounts, so a
/// checked subtraction would panic on ordinary input such as 22 work items
/// sharing one op. Saturating instead pins heavily discounted items at cost 0,
/// which still sorts them first.
fn discount_q16(value: u16, amount: u16) -> u16 {
    value.saturating_sub(amount)
}
590
/// Add one to a conflict degree, panicking with a sharding hint on overflow.
fn increment_degree(value: u32) -> u32 {
    match value.checked_add(1) {
        Some(next) => next,
        None => panic!(
            "megakernel fusion conflict degree overflowed u32. Fix: shard the exchange graph before planning."
        ),
    }
}
598
/// Greedy maximal independent set over a dense exchange graph, visiting items
/// in `order`.
///
/// `exchange_adj` is a row-major `n*n` matrix. Items `i != j` conflict when
/// either `exchange_adj[i*n + j]` or `exchange_adj[j*n + i]` is non-zero; the
/// diagonal is ignored. `result` is zeroed and `selected` cleared first.
/// Then each item of `order` that is `< n`, not already kept, and free of
/// conflicts with the items kept so far is kept: `result[item]` becomes 1 and
/// `item` is appended to `selected`.
///
/// Conflict rows are `u64` bitsets. Batches of up to 256 items keep them in
/// fixed-size stack arrays; larger batches use one heap allocation.
fn select_ordered_maximal(
    exchange_adj: &[u32],
    n: usize,
    order: &[usize],
    selected: &mut Vec<usize>,
    result: &mut [u32],
) {
    result.fill(0);
    selected.clear();

    // The second const argument is the stack mask size: (64 * WORDS) rows of
    // WORDS words each, matching the largest batch of that word count.
    match n.div_ceil(64) {
        0 => {}
        1 => select_with_stack_bitsets::<1, 64>(exchange_adj, n, order, selected, result),
        2 => select_with_stack_bitsets::<2, 256>(exchange_adj, n, order, selected, result),
        3 => select_with_stack_bitsets::<3, 576>(exchange_adj, n, order, selected, result),
        4 => select_with_stack_bitsets::<4, 1024>(exchange_adj, n, order, selected, result),
        words => {
            let mut conflict_masks = vec![0_u64; n * words];
            let mut selected_mask = vec![0_u64; words];
            fill_conflict_bitsets(exchange_adj, n, words, &mut conflict_masks);
            greedy_select_bitsets(
                &conflict_masks,
                n,
                words,
                order,
                &mut selected_mask,
                selected,
                result,
            );
        }
    }
}

/// Stack-allocated bitset path for `n <= 64 * WORDS` items, where
/// `MASK_WORDS >= n * WORDS`.
fn select_with_stack_bitsets<const WORDS: usize, const MASK_WORDS: usize>(
    exchange_adj: &[u32],
    n: usize,
    order: &[usize],
    selected: &mut Vec<usize>,
    result: &mut [u32],
) {
    debug_assert!(n <= WORDS * 64 && n * WORDS <= MASK_WORDS);
    let mut mask_storage = [0_u64; MASK_WORDS];
    let mut selected_mask = [0_u64; WORDS];
    let conflict_masks = &mut mask_storage[..n * WORDS];
    fill_conflict_bitsets(exchange_adj, n, WORDS, conflict_masks);
    greedy_select_bitsets(
        conflict_masks,
        n,
        WORDS,
        order,
        &mut selected_mask,
        selected,
        result,
    );
}

/// Write the symmetric conflict bitsets of `n` items into `masks`, which holds
/// `words` `u64`s per item in row-major order and must start zeroed.
fn fill_conflict_bitsets(exchange_adj: &[u32], n: usize, words: usize, masks: &mut [u64]) {
    for i in 0..n {
        for j in (i + 1)..n {
            if exchange_adj[i * n + j] != 0 || exchange_adj[j * n + i] != 0 {
                masks[i * words + j / 64] |= 1_u64 << (j % 64);
                masks[j * words + i / 64] |= 1_u64 << (i % 64);
            }
        }
    }
}

/// Keep each in-range, not-yet-kept item of `order` whose conflict row does
/// not intersect the bitset of already kept items in `selected_mask`.
fn greedy_select_bitsets(
    masks: &[u64],
    n: usize,
    words: usize,
    order: &[usize],
    selected_mask: &mut [u64],
    selected: &mut Vec<usize>,
    result: &mut [u32],
) {
    for &item in order {
        if item >= n {
            continue;
        }
        let word = item / 64;
        let bit = 1_u64 << (item % 64);
        let row = &masks[item * words..(item + 1) * words];
        // The own-bit check keeps a repeated `order` entry from being pushed
        // into `selected` twice.
        let blocked = selected_mask[word] & bit != 0
            || row
                .iter()
                .zip(selected_mask.iter())
                .any(|(&conflicts, &kept)| conflicts & kept != 0);
        if !blocked {
            result[item] = 1;
            selected_mask[word] |= bit;
            selected.push(item);
        }
    }
}
834
835#[cfg(test)]
836mod tests;
837