Skip to main content

vyre_runtime/megakernel/planner/
fusion.rs

//! Fusion-subset selection used by megakernel batch dispatchers.
//!
//! This runtime path is deliberately self-contained: it does not call
//! self-substrate CPU reference solvers while preparing megakernel work.
5
6use super::MegakernelWorkItem;
7
8mod prologue;
9pub use prologue::shared_prologue_length;
10
11/// Hard cap for dense exchange-graph planning.
12///
13/// This avoids dense O(n*n) matrix growth in pathological batches.
14pub(super) const MAX_DENSE_FUSION_ITEMS: usize = 4096;
15
/// Reusable buffers for megakernel fusion-subset selection.
///
/// Runtime schedulers can keep one scratch object per worker and avoid
/// allocating the visit-order, conflict-degree, selected-item, and result
/// buffers on every batch.
#[derive(Debug, Default)]
pub struct FusionSelectionScratch {
    /// Candidate visit order: reset to `0..n`, then sorted by the selector.
    order: Vec<usize>,
    /// 0/1 keep vector written by the last selector invocation.
    result: Vec<u32>,
    /// Per-item number of conflicting partners.
    conflict_degrees: Vec<u32>,
    /// Items kept by the greedy pass, in acceptance order.
    selected: Vec<usize>,
}

impl FusionSelectionScratch {
    /// Selected 0/1 fusion vector from the last selector invocation.
    #[must_use]
    pub fn result(&self) -> &[u32] {
        &self.result
    }

    /// Move out the current result while retaining the other scratch buffers.
    ///
    /// The scratch is left with an empty result until the next selection.
    #[must_use]
    pub fn take_result(&mut self) -> Vec<u32> {
        std::mem::take(&mut self.result)
    }

    /// Reset every buffer for an `n`-item selection.
    ///
    /// Afterwards `order` is `0..n`, `result` and `conflict_degrees` are `n`
    /// zeros, and `selected` is empty. Existing allocations are reused.
    fn prepare(&mut self, n: usize) {
        self.order.clear();
        self.order.extend(0..n);
        self.result.clear();
        self.result.resize(n, 0);
        self.conflict_degrees.clear();
        self.conflict_degrees.resize(n, 0);
        self.selected.clear();
    }
}
51
/// Input-shape error reported by the checked megakernel fusion selectors.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FusionSelectionError {
    /// `n * n` does not fit in `usize`. An item count that does not itself
    /// fit in `usize` is reported with `n = usize::MAX`.
    ExchangeSizeOverflow {
        /// Requested item count.
        n: usize,
    },
    /// The cost vector does not hold exactly `n` entries.
    CostLen {
        /// Expected number of costs (`n`).
        expected: usize,
        /// Number of costs actually supplied.
        actual: usize,
    },
    /// The exchange adjacency does not hold exactly `n * n` entries.
    ExchangeAdjLen {
        /// Expected number of row-major adjacency cells (`n * n`).
        expected: usize,
        /// Number of adjacency cells actually supplied.
        actual: usize,
    },
}

impl std::fmt::Display for FusionSelectionError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Every message names the problem and ends with a concrete fix hint.
        match *self {
            Self::ExchangeSizeOverflow { n } => write!(
                f,
                "megakernel fusion selector n*n overflow for n={n}. Fix: shard the work batch before fusion selection."
            ),
            Self::CostLen { expected, actual } => write!(
                f,
                "megakernel fusion selector cost length {actual} does not match n={expected}. Fix: pass one cost per work item."
            ),
            Self::ExchangeAdjLen { expected, actual } => write!(
                f,
                "megakernel fusion selector exchange_adj length {actual} does not match n*n={expected}. Fix: pass a dense row-major n*n exchange graph."
            ),
        }
    }
}

impl std::error::Error for FusionSelectionError {}
96
97fn validate_selector_shape(
98    cost_len: usize,
99    n: u32,
100    exchange_adj_len: usize,
101) -> Result<(usize, usize), FusionSelectionError> {
102    let n_usize = usize::try_from(n)
103        .map_err(|_| FusionSelectionError::ExchangeSizeOverflow { n: usize::MAX })?;
104    let cells = n_usize
105        .checked_mul(n_usize)
106        .ok_or(FusionSelectionError::ExchangeSizeOverflow { n: n_usize })?;
107    if cost_len != n_usize {
108        return Err(FusionSelectionError::CostLen {
109            expected: n_usize,
110            actual: cost_len,
111        });
112    }
113    if exchange_adj_len != cells {
114        return Err(FusionSelectionError::ExchangeAdjLen {
115            expected: cells,
116            actual: exchange_adj_len,
117        });
118    }
119    Ok((n_usize, cells))
120}
121
122/// Reusable scratch for compact runtime fusion planning.
123///
124/// Concrete drivers own command submission. Runtime owns the queue-shaping
125/// policy: cost seeds, divergence flags, exchange graph, and selector output.
126#[derive(Debug, Default)]
127pub struct CompactFusionPlanningScratch {
128    costs_q16: Vec<u16>,
129    stalks: Vec<f32>,
130    diffused_stalks: Vec<f32>,
131    effective_divergence: Vec<u32>,
132    deltas: Vec<f32>,
133    sorted_deltas: Vec<f32>,
134    exchange_adj: Vec<u32>,
135    order: Vec<usize>,
136    selection: FusionSelectionScratch,
137}
138
139impl CompactFusionPlanningScratch {
140    /// Last exchange adjacency matrix, row-major `n*n`.
141    #[must_use]
142    pub fn exchange_adj(&self) -> &[u32] {
143        &self.exchange_adj
144    }
145
146    /// Last 0/1 selection vector.
147    #[must_use]
148    pub fn selected(&self) -> &[u32] {
149        self.selection.result()
150    }
151}
152
153/// Build the compact megakernel fusion plan for one work batch.
154///
155/// Returns the selector's 0/1 keep vector. The matching exchange adjacency is
156/// retained in `scratch.exchange_adj()` for provenance and diagnostics.
157pub fn plan_compact_fusion_into<'a>(
158    work_items: &[MegakernelWorkItem],
159    scratch: &'a mut CompactFusionPlanningScratch,
160) -> &'a [u32] {
161    let n = work_items.len();
162    if n > MAX_DENSE_FUSION_ITEMS {
163        scratch.selection.prepare(n);
164        scratch.selection.result.fill(1);
165        scratch.exchange_adj.clear();
166        return scratch.selection.result();
167    }
168
169    if n == 0 {
170        scratch.costs_q16.clear();
171        scratch.stalks.clear();
172        scratch.diffused_stalks.clear();
173        scratch.effective_divergence.clear();
174        scratch.deltas.clear();
175        scratch.sorted_deltas.clear();
176        scratch.exchange_adj.clear();
177        scratch.selection.prepare(0);
178        return scratch.selection.result();
179    }
180
181    scratch.costs_q16.clear();
182    scratch.costs_q16.resize(n, u16::MAX);
183
184    scratch.stalks.clear();
185    scratch.stalks.extend(
186        work_items
187            .iter()
188            .enumerate()
189            .map(|(item_idx, _item)| (item_idx as f32) * 0.001),
190    );
191    scratch.diffused_stalks.clear();
192    scratch.diffused_stalks.extend_from_slice(&scratch.stalks);
193    for _ in 0..8 {
194        for value in &mut scratch.diffused_stalks {
195            *value -= 0.5_f32 * 0.7_f32 * *value;
196        }
197    }
198
199    let divergence_threshold = 0.05_f32;
200    let mut delta_sum = 0.0_f32;
201    let mut delta_max = 0.0_f32;
202    scratch.effective_divergence.clear();
203    for (&initial, &diffused) in scratch.stalks.iter().zip(scratch.diffused_stalks.iter()) {
204        let delta = (initial - diffused).abs();
205        delta_sum += delta;
206        delta_max = delta_max.max(delta);
207        scratch
208            .effective_divergence
209            .push(u32::from(delta > divergence_threshold));
210    }
211
212    let n_f32 = n as f32;
213    let gap_signal = if delta_max > 0.0_f32 && n_f32 > 0.0_f32 {
214        delta_sum / (n_f32 * delta_max)
215    } else {
216        1.0_f32
217    };
218    if gap_signal < 0.3 {
219        scratch.deltas.clear();
220        scratch.deltas.extend(
221            scratch
222                .stalks
223                .iter()
224                .zip(scratch.diffused_stalks.iter())
225                .map(|(s, d)| (s - d).abs()),
226        );
227        scratch.sorted_deltas.clear();
228        scratch.sorted_deltas.extend_from_slice(&scratch.deltas);
229        scratch
230            .sorted_deltas
231            .sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
232        let median = scratch
233            .sorted_deltas
234            .get(scratch.sorted_deltas.len() / 2)
235            .copied()
236            .unwrap_or(0.0);
237        for (flag, delta) in scratch
238            .effective_divergence
239            .iter_mut()
240            .zip(scratch.deltas.iter())
241        {
242            if *delta < median {
243                *flag = 0;
244            }
245        }
246    }
247
248    scratch.exchange_adj.clear();
249    let dense_cells = n * n;
250    scratch.exchange_adj.resize(dense_cells, 0);
251    let mut has_exchange_conflict = false;
252
253    let mut has_op_conflict = false;
254    scratch.order.clear();
255    scratch.order.extend(0..n);
256    if scratch.order.len() > 1 {
257        scratch
258            .order
259            .sort_unstable_by_key(|&item_idx| work_items[item_idx].op_handle);
260        if scratch
261            .order
262            .windows(2)
263            .any(|window| work_items[window[0]].op_handle == work_items[window[1]].op_handle)
264        {
265            has_op_conflict = true;
266        }
267    }
268    let has_output_input_chain = (0..n.checked_sub(1).unwrap_or(0)).any(|i| {
269        work_items.get(i).map(|w| w.output_handle) == work_items.get(i + 1).map(|w| w.input_handle)
270    });
271    let has_divergence_conflict = scratch.effective_divergence.iter().any(|&v| v != 0);
272    scratch.selection.prepare(n);
273
274    if !has_op_conflict && !has_divergence_conflict {
275        if has_output_input_chain {
276            for cost in scratch.costs_q16.iter_mut() {
277                *cost = discount_q16(*cost, 3_276);
278            }
279        }
280
281        scratch.selection.result.fill(1);
282        return scratch.selection.result();
283    }
284
285    {
286        let conflict_degrees = &mut scratch.selection.conflict_degrees;
287        for i in 0..n {
288            let row_start = i * n;
289            for j in 0..n {
290                if i == j {
291                    continue;
292                }
293                let same_op = work_items[i].op_handle == work_items[j].op_handle;
294                if n <= 32 && same_op {
295                    scratch.costs_q16[i] = discount_q16(scratch.costs_q16[i], 3_276);
296                }
297                let divergent =
298                    scratch.effective_divergence[i] != 0 && scratch.effective_divergence[j] != 0;
299                if same_op || divergent {
300                    scratch.exchange_adj[row_start + j] = 1;
301                    if i < j {
302                        conflict_degrees[i] = increment_degree(conflict_degrees[i]);
303                        conflict_degrees[j] = increment_degree(conflict_degrees[j]);
304                    }
305                    has_exchange_conflict = true;
306                }
307            }
308        }
309    }
310    if has_output_input_chain {
311        for cost in scratch.costs_q16.iter_mut() {
312            *cost = discount_q16(*cost, 3_276);
313        }
314    }
315    if !has_exchange_conflict {
316        scratch.selection.result.fill(1);
317        return scratch.selection.result();
318    }
319
320    let conflict_degrees = &scratch.selection.conflict_degrees;
321    scratch.selection.order.sort_unstable_by(|&a, &b| {
322        scratch.costs_q16[a]
323            .cmp(&scratch.costs_q16[b])
324            .then_with(|| conflict_degrees[a].cmp(&conflict_degrees[b]))
325            .then_with(|| a.cmp(&b))
326    });
327    select_ordered_maximal(
328        &scratch.exchange_adj,
329        n,
330        &scratch.selection.order,
331        &mut scratch.selection.selected,
332        &mut scratch.selection.result,
333    );
334    scratch.selection.result()
335}
336
337/// Compute a deterministic maximal fusion subset for a batch of megakernel work items.
338///
339/// `costs[i]` is the dispatch cost of program `i` (lower is cheaper).
340/// `exchange_adj[i*n+j]` is non-zero when fusing `i` and `j` is
341/// incompatible (memory overflow, sync class boundary, etc.).
342///
343/// Returns a 0/1 selection vector of length `n`.
344#[must_use]
345pub fn select_fused_subset(costs: &[f64], n: u32, exchange_adj: &[u32]) -> Vec<u32> {
346    let mut scratch = FusionSelectionScratch::default();
347    select_fused_subset_into(costs, n, exchange_adj, &mut scratch);
348    scratch.take_result()
349}
350
351/// Compute the optimal fusion subset into reusable scratch buffers.
352pub fn select_fused_subset_into(
353    costs: &[f64],
354    n: u32,
355    exchange_adj: &[u32],
356    scratch: &mut FusionSelectionScratch,
357) {
358    if let Ok((n_usize, _cells)) = validate_selector_shape(costs.len(), n, exchange_adj.len()) {
359        if n_usize <= MAX_DENSE_FUSION_ITEMS && exchange_adj.iter().all(|&edge| edge == 0) {
360            scratch.prepare(n_usize);
361            scratch.result.fill(1);
362            return;
363        }
364    }
365    if select_fused_subset_checked_into(costs, n, exchange_adj, scratch).is_err() {
366        scratch.prepare(0);
367    }
368}
369
370/// Checked selector variant that reports malformed planner input.
371pub fn select_fused_subset_checked_into(
372    costs: &[f64],
373    n: u32,
374    exchange_adj: &[u32],
375    scratch: &mut FusionSelectionScratch,
376) -> Result<(), FusionSelectionError> {
377    let (n_usize, _cells) = validate_selector_shape(costs.len(), n, exchange_adj.len())?;
378    if n_usize > MAX_DENSE_FUSION_ITEMS {
379        scratch.prepare(n_usize);
380        scratch.result.fill(1);
381        return Ok(());
382    }
383    scratch.prepare(n_usize);
384    if exchange_adj.iter().all(|&edge| edge == 0) {
385        scratch.result.fill(1);
386        return Ok(());
387    }
388    if !compute_conflict_degrees_with_conflict(exchange_adj, n_usize, &mut scratch.conflict_degrees)
389    {
390        scratch.result.fill(1);
391        return Ok(());
392    }
393    scratch.order.sort_unstable_by(|&a, &b| {
394        costs[a]
395            .total_cmp(&costs[b])
396            .then_with(|| scratch.conflict_degrees[a].cmp(&scratch.conflict_degrees[b]))
397            .then_with(|| a.cmp(&b))
398    });
399    select_ordered_maximal(
400        exchange_adj,
401        n_usize,
402        &scratch.order,
403        &mut scratch.selected,
404        &mut scratch.result,
405    );
406    Ok(())
407}
408
409/// Compact-cost selector for hot runtime dispatchers.
410///
411/// `costs_q16[i]` is a normalized fixed-point dispatch cost where lower is
412/// cheaper. This avoids carrying `Vec<f64>` scratch through runtime hot paths;
413/// the exact matroid rounder still receives the same exchange graph.
414#[must_use]
415pub fn select_fused_subset_compact(costs_q16: &[u16], n: u32, exchange_adj: &[u32]) -> Vec<u32> {
416    let mut scratch = FusionSelectionScratch::default();
417    select_fused_subset_compact_into(costs_q16, n, exchange_adj, &mut scratch);
418    scratch.take_result()
419}
420
421/// Compact-cost selector using caller-owned scratch buffers.
422pub fn select_fused_subset_compact_into(
423    costs_q16: &[u16],
424    n: u32,
425    exchange_adj: &[u32],
426    scratch: &mut FusionSelectionScratch,
427) {
428    if let Ok((n_usize, _cells)) = validate_selector_shape(costs_q16.len(), n, exchange_adj.len()) {
429        if n_usize <= MAX_DENSE_FUSION_ITEMS && exchange_adj.iter().all(|&edge| edge == 0) {
430            scratch.prepare(n_usize);
431            scratch.result.fill(1);
432            return;
433        }
434    }
435    if select_fused_subset_compact_checked_into(costs_q16, n, exchange_adj, scratch).is_err() {
436        scratch.prepare(0);
437    }
438}
439
440/// Checked compact selector variant that reports malformed planner input.
441pub fn select_fused_subset_compact_checked_into(
442    costs_q16: &[u16],
443    n: u32,
444    exchange_adj: &[u32],
445    scratch: &mut FusionSelectionScratch,
446) -> Result<(), FusionSelectionError> {
447    let (n_usize, _cells) = validate_selector_shape(costs_q16.len(), n, exchange_adj.len())?;
448    if n_usize > MAX_DENSE_FUSION_ITEMS {
449        scratch.prepare(n_usize);
450        scratch.result.fill(1);
451        return Ok(());
452    }
453    scratch.prepare(n_usize);
454    if exchange_adj.iter().all(|&edge| edge == 0) {
455        scratch.result.fill(1);
456        return Ok(());
457    }
458    if !compute_conflict_degrees_with_conflict(exchange_adj, n_usize, &mut scratch.conflict_degrees)
459    {
460        scratch.result.fill(1);
461        return Ok(());
462    }
463    scratch.order.sort_unstable_by(|&a, &b| {
464        costs_q16[a]
465            .cmp(&costs_q16[b])
466            .then_with(|| scratch.conflict_degrees[a].cmp(&scratch.conflict_degrees[b]))
467            .then_with(|| a.cmp(&b))
468    });
469    select_ordered_maximal(
470        exchange_adj,
471        n_usize,
472        &scratch.order,
473        &mut scratch.selected,
474        &mut scratch.result,
475    );
476    Ok(())
477}
478
479/// Compute a cost-ordered maximal fusion subset with the same output contract
480/// as [`select_fused_subset`].
481#[must_use]
482
483pub fn select_optimal_fused_subset(costs: &[f64], n: u32, exchange_adj: &[u32]) -> Vec<u32> {
484    select_fused_subset(costs, n, exchange_adj)
485}
486
487/// Runtime-compatible selector entry point that preserves the historical API.
488#[must_use]
489pub fn select_fused_subset_with_rate(costs: &[f64], n: u32, exchange_adj: &[u32]) -> Vec<u32> {
490    select_fused_subset(costs, n, exchange_adj)
491}
492
493/// Select a cost-ordered fused subset, then eliminate arms whose gate
494/// predicates have already proven them to be no-ops for this dispatch.
495///
496/// This is the runtime-facing C5 entry point: it keeps the historical
497/// selection algorithm unchanged, then applies [`prune_dead_arms_inplace`]
498/// before the caller materializes the launch sequence.
499#[must_use]
500pub fn select_fused_subset_pruned(
501    costs: &[f64],
502    n: u32,
503    exchange_adj: &[u32],
504    dead_mask: &[bool],
505) -> Vec<u32> {
506    let mut selection = select_fused_subset(costs, n, exchange_adj);
507    prune_dead_arms_inplace(&mut selection, dead_mask);
508    selection
509}
510
511/// Reusable-scratch variant of [`select_fused_subset_pruned`].
512pub fn select_fused_subset_pruned_into(
513    costs: &[f64],
514    n: u32,
515    exchange_adj: &[u32],
516    dead_mask: &[bool],
517    scratch: &mut FusionSelectionScratch,
518) {
519    select_fused_subset_into(costs, n, exchange_adj, scratch);
520    prune_dead_arms_inplace(&mut scratch.result, dead_mask);
521}
522
/// ROADMAP C5 substrate: gated no-op middle-arm elimination.
///
/// `selection` is a 0/1 vector with one entry per arm in the megakernel
/// dispatch sequence. `dead_mask` has the same length, and `dead_mask[i]`
/// being `true` means arm `i` has been proven to be a no-op for this dispatch
/// (its gate predicate folds to false, its output equals its input, etc.).
/// Every selected arm that is marked dead is zeroed in place. The number of
/// arms zeroed is returned so the caller can log or telemeter the win.
///
/// A length mismatch is a caller contract violation. The selection is left
/// untouched and zero is returned, so reusable planner scratch is never
/// abandoned through a panic while a checked caller records the malformed
/// planner input.
///
/// Example: in an inference megakernel, arm 1 is a `mask × value` step gated
/// on `mask != 0`. If the static analyzer proves the mask buffer is all-zero
/// for this batch, dispatch can elide arm 1 entirely. Without the elision,
/// the GPU launches a full kernel that reads both buffers, computes the
/// multiplication, and writes a zero result back, which is pure waste.
pub fn prune_dead_arms_inplace(selection: &mut [u32], dead_mask: &[bool]) -> u32 {
    if selection.len() != dead_mask.len() {
        return 0;
    }
    selection
        .iter_mut()
        .zip(dead_mask.iter().copied())
        .filter(|(slot, dead)| *dead && **slot != 0)
        .fold(0_u32, |eliminated, (slot, _)| {
            *slot = 0;
            eliminated.saturating_add(1)
        })
}
557
558fn compute_conflict_degrees_with_conflict(exchange_adj: &[u32], n: usize, out: &mut [u32]) -> bool {
559    debug_assert_eq!(out.len(), n);
560    out.fill(0);
561    let mut has_conflict = false;
562    for i in 0..n {
563        let row = i * n;
564        for j in (i + 1)..n {
565            if exchange_adj[row + j] != 0 || exchange_adj[j * n + i] != 0 {
566                out[i] = increment_degree(out[i]);
567                out[j] = increment_degree(out[j]);
568                has_conflict = true;
569            }
570        }
571    }
572    has_conflict
573}
574
/// Lower a Q16 cost by `amount`, clamping at zero instead of wrapping.
fn discount_q16(value: u16, amount: u16) -> u16 {
    value.checked_sub(amount).unwrap_or(0)
}
578
/// Add one to a conflict degree, pinning at `u32::MAX` instead of wrapping.
fn increment_degree(value: u32) -> u32 {
    value.checked_add(1).unwrap_or(u32::MAX)
}
582
/// Greedy, order-driven maximal independent set over a dense conflict graph.
///
/// Items are visited in `order`, and an item is kept when it does not
/// conflict with any item kept before it. Items `i != j` conflict when
/// `exchange_adj[i * n + j]` or `exchange_adj[j * n + i]` is non-zero, so the
/// graph is treated as undirected and the diagonal is ignored. Entries of
/// `order` that are `>= n` are skipped.
///
/// On return, every element of `result` is 0 except `result[i] == 1` for each
/// kept item, and `selected` lists the kept items in acceptance order.
/// `exchange_adj` must hold at least `n * n` cells and `result` at least `n`
/// entries; callers in this module size them accordingly.
///
/// Conflict bitsets live on the stack for `n <= 256`, sized per 64-item tier
/// exactly as before. Larger batches use heap-allocated bitsets. All tiers
/// share one greedy pass.
fn select_ordered_maximal(
    exchange_adj: &[u32],
    n: usize,
    order: &[usize],
    selected: &mut Vec<usize>,
    result: &mut [u32],
) {
    result.fill(0);
    selected.clear();

    if n == 0 {
        return;
    }

    // One bitset row of `words` u64 words per item.
    let words = n.div_ceil(64);
    let cells = n * words;
    let mut inline_selected = [0_u64; 4];
    match words {
        1 => {
            let mut masks = [0_u64; 64];
            select_ordered_maximal_bitset(
                exchange_adj,
                n,
                order,
                &mut masks[..cells],
                &mut inline_selected[..words],
                selected,
                result,
            );
        }
        2 => {
            let mut masks = [0_u64; 128 * 2];
            select_ordered_maximal_bitset(
                exchange_adj,
                n,
                order,
                &mut masks[..cells],
                &mut inline_selected[..words],
                selected,
                result,
            );
        }
        3 => {
            let mut masks = [0_u64; 192 * 3];
            select_ordered_maximal_bitset(
                exchange_adj,
                n,
                order,
                &mut masks[..cells],
                &mut inline_selected[..words],
                selected,
                result,
            );
        }
        4 => {
            let mut masks = [0_u64; 256 * 4];
            select_ordered_maximal_bitset(
                exchange_adj,
                n,
                order,
                &mut masks[..cells],
                &mut inline_selected[..words],
                selected,
                result,
            );
        }
        _ => {
            let mut masks = vec![0_u64; cells];
            let mut wide_selected = vec![0_u64; words];
            select_ordered_maximal_bitset(
                exchange_adj,
                n,
                order,
                &mut masks,
                &mut wide_selected,
                selected,
                result,
            );
        }
    }
}

/// Shared greedy pass for [`select_ordered_maximal`].
///
/// `selected_mask` holds `words` zeroed u64 words, with `words * 64 >= n`.
/// `conflict_masks` holds at least `n * words` zeroed words, where row `i`
/// (`conflict_masks[i * words..(i + 1) * words]`) becomes the conflict bitset
/// of item `i`.
fn select_ordered_maximal_bitset(
    exchange_adj: &[u32],
    n: usize,
    order: &[usize],
    conflict_masks: &mut [u64],
    selected_mask: &mut [u64],
    selected: &mut Vec<usize>,
    result: &mut [u32],
) {
    let words = selected_mask.len();

    // Symmetrize the (possibly one-directional) adjacency into bitset rows.
    for i in 0..n {
        for j in (i + 1)..n {
            if exchange_adj[i * n + j] != 0 || exchange_adj[j * n + i] != 0 {
                conflict_masks[i * words + j / 64] |= 1_u64 << (j % 64);
                conflict_masks[j * words + i / 64] |= 1_u64 << (i % 64);
            }
        }
    }

    for &item in order {
        if item >= n {
            continue;
        }
        let row = &conflict_masks[item * words..(item + 1) * words];
        let clashes = row
            .iter()
            .zip(selected_mask.iter())
            .any(|(conflicts, chosen)| conflicts & chosen != 0);
        if !clashes {
            result[item] = 1;
            selected_mask[item / 64] |= 1_u64 << (item % 64);
            selected.push(item);
        }
    }
}
818
819#[cfg(test)]
820mod tests;