Skip to main content

vyre_reference/
subgroup.rs

1//! Subgroup simulator for the CPU reference interpreter.
2//!
3//! The reference interpreter executes one invocation at a time, but Cat-C
4//! subgroup ops need lane-collective semantics to stay oracle-worthy once the
5//! lowering path emits backend subgroup expressions. This module provides a small,
6//! deterministic simulator over a logical subgroup of lanes.
7
/// Deterministic CPU model of a hardware subgroup/wave.
///
/// The simulator stores only the logical subgroup width. Every operation takes
/// per-lane data as slices and considers at most `width` lanes of it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SubgroupSimulator {
    /// Number of lanes in one subgroup. It is always at least 1 because
    /// `SubgroupSimulator::new` clamps 0 up to 1.
    width: usize,
}

impl Default for SubgroupSimulator {
    /// A 32-lane subgroup.
    fn default() -> Self {
        Self { width: 32 }
    }
}

impl SubgroupSimulator {
    /// Construct a simulator with a fixed subgroup width.
    ///
    /// A `width` of 0 is clamped to 1, so every simulator has at least one lane.
    #[must_use]
    pub fn new(width: usize) -> Self {
        Self {
            width: width.max(1),
        }
    }

    /// Configured subgroup width (always `>= 1`).
    #[must_use]
    pub const fn width(&self) -> usize {
        self.width
    }

    /// Encode lane predicates as a ballot bitmask.
    ///
    /// This is a fixed-size convenience wrapper around
    /// [`SubgroupSimulator::ballot_slice`].
    #[must_use]
    pub fn ballot<const N: usize>(&self, mask: &[bool; N]) -> u32 {
        self.ballot_slice(mask)
    }

    /// Encode an arbitrary lane predicate slice as a ballot bitmask.
    ///
    /// Bit `i` of the result is set when `mask[i]` is `true`. Only the first
    /// `min(mask.len(), width, 32)` lanes are encoded:
    /// - predicates past the subgroup width are ignored;
    /// - lanes 32 and above cannot be represented in the `u32` result, so they
    ///   are dropped for subgroups wider than 32.
    #[must_use]
    pub fn ballot_slice(&self, mask: &[bool]) -> u32 {
        let active = mask.len().min(self.width).min(32);
        let mut bits = 0u32;
        for (lane, &flag) in mask.iter().take(active).enumerate() {
            if flag {
                // `lane < 32` is guaranteed by `active`, so the shift cannot overflow.
                bits |= 1u32 << lane;
            }
        }
        bits
    }

    /// Permute values by source-lane indices.
    ///
    /// Output lane `i` receives the value held by lane `src_lanes[i]`.
    /// - The subgroup holds the first `min(values.len(), width)` entries of
    ///   `values`.
    /// - A source lane outside that range reads as `0`. This includes indices
    ///   that exist in `values` but lie past the subgroup width.
    /// - The result has `min(values.len(), src_lanes.len(), width)` lanes.
    #[must_use]
    pub fn shuffle(&self, values: &[u32], src_lanes: &[u32]) -> Vec<u32> {
        // Restrict reads to this subgroup's lanes. Without this bound, a source
        // index past `width` would leak a value from a neighbouring subgroup
        // whenever `values` is longer than the subgroup.
        let lanes = &values[..values.len().min(self.width)];
        let active = lanes.len().min(src_lanes.len());
        src_lanes
            .iter()
            .take(active)
            .map(|&src| {
                usize::try_from(src)
                    .ok()
                    .and_then(|src| lanes.get(src))
                    .copied()
                    .unwrap_or(0)
            })
            .collect()
    }

    /// Wrapping sum reduction across the first `width` lanes of `values`.
    ///
    /// Returns `0` for an empty slice. Lanes past the subgroup width are ignored.
    #[must_use]
    pub fn add(&self, values: &[u32]) -> u32 {
        values
            .iter()
            .take(self.width)
            .copied()
            .fold(0u32, u32::wrapping_add)
    }

    /// Bounds of the subgroup containing `lane_index` within `lane_count`.
    ///
    /// Returns the half-open range `(start, end)` of the `width`-aligned
    /// subgroup containing `lane_index`. `end` is clipped to `lane_count` for a
    /// trailing partial subgroup.
    ///
    /// `start` is also clipped to `lane_count`, so the result always satisfies
    /// `start <= end <= lane_count` and is safe to use as a slice range. When
    /// the whole subgroup of `lane_index` lies at or past `lane_count`, the
    /// result is the empty range `(lane_count, lane_count)`.
    #[must_use]
    pub fn subgroup_bounds(&self, lane_count: usize, lane_index: usize) -> (usize, usize) {
        let start = ((lane_index / self.width) * self.width).min(lane_count);
        let end = lane_count.min(start.saturating_add(self.width));
        (start, end)
    }
}
83
#[cfg(test)]
mod tests {
    use super::SubgroupSimulator;
    use proptest::prelude::*;
    use rayon::prelude::*;

    #[test]
    fn ballot_sets_expected_bits() {
        let simulator = SubgroupSimulator::default();
        assert_eq!(simulator.ballot(&[true, false, true, true]), 0b1101);
    }

    #[test]
    fn ballot_ignores_lanes_past_width() {
        // Only the first `width` predicates belong to the subgroup.
        let simulator = SubgroupSimulator::new(2);
        assert_eq!(simulator.ballot(&[true, true, true, true]), 0b11);
    }

    #[test]
    fn shuffle_zeroes_out_of_range_lanes() {
        let simulator = SubgroupSimulator::new(4);
        assert_eq!(
            simulator.shuffle(&[10, 20, 30, 40], &[0, 2, 5, 1]),
            vec![10, 30, 0, 20]
        );
    }

    #[test]
    fn subgroup_bounds_clip_trailing_partial_subgroup() {
        let simulator = SubgroupSimulator::new(4);
        assert_eq!(simulator.subgroup_bounds(10, 0), (0, 4));
        assert_eq!(simulator.subgroup_bounds(10, 5), (4, 8));
        assert_eq!(simulator.subgroup_bounds(10, 9), (8, 10));
    }

    #[test]
    fn zero_width_is_clamped_to_one() {
        assert_eq!(SubgroupSimulator::new(0).width(), 1);
    }

    proptest! {
        #[test]
        fn subgroup_add_matches_parallel_wrapping_sum(values in prop::collection::vec(any::<u32>(), 0..128)) {
            let simulator = SubgroupSimulator::new(values.len().max(1));
            let expected = values.par_iter().copied().reduce(|| 0u32, u32::wrapping_add);
            prop_assert_eq!(simulator.add(&values), expected);
        }

        // Exercises the width truncation in `add`, which the test above never
        // hits because it always sizes the subgroup to the whole input.
        #[test]
        fn subgroup_add_only_sums_first_width_lanes(
            values in prop::collection::vec(any::<u32>(), 0..128),
            width in 1usize..160
        ) {
            let simulator = SubgroupSimulator::new(width);
            let active = &values[..values.len().min(width)];
            let expected = active.par_iter().copied().reduce(|| 0u32, u32::wrapping_add);
            prop_assert_eq!(simulator.add(&values), expected);
        }
    }
}
109            let expected = values.par_iter().copied().reduce(|| 0u32, u32::wrapping_add);
110            prop_assert_eq!(simulator.add(&values), expected);
111        }
112    }
113}