Skip to main content

vyre_runtime/megakernel/advanced/
parallel_dfa.rs

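//! Parallel prefix-composition fragments for DFA scanning.
//!
//! Replaces a scalar byte-by-byte `loop_for` scan ($O(N)$ sequential steps) with a
//! subgroup-cooperative prefix composition of per-byte DFA transition functions:
//! one byte per lane, combined in $O(\log W)$ doubling stages for a subgroup of
//! $W$ lanes.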
5
6use vyre_foundation::ir::{Expr, Node};
7
/// Subgroup width used by `ParallelDfaBindings::default`.
const DEFAULT_SUBGROUP_WIDTH: u32 = 32;
/// Columns per DFA state in the dense transition table: one per possible byte
/// value, so a transition lives at `state * ALPHABET_SIZE + byte`.
const ALPHABET_SIZE: u32 = u8::MAX as u32 + 1;
10
/// Buffer and variable names referenced by the subgroup DFA prefix-composition
/// fragment, together with the number of lanes it composes over.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ParallelDfaBindings {
    /// Buffer holding the dense transition table(s), indexed as
    /// `transition_base + state * 256 + byte`.
    pub transitions: &'static str,
    /// Buffer holding the input bytes being scanned.
    pub haystack: &'static str,
    /// Scratch table holding each lane's current composed transition function
    /// (one row of `state_count` entries per lane).
    pub lane_prefix: &'static str,
    /// Scratch table that receives the next composition stage before it is
    /// copied back into `lane_prefix`.
    pub lane_next: &'static str,
    /// Output buffer receiving one DFA state per lane, indexed by the lane's
    /// local invocation id.
    pub out_state_by_lane: &'static str,
    /// Variable holding the offset of the first byte of the chunk.
    pub file_start: &'static str,
    /// Variable holding the exclusive end offset of the file/chunk.
    pub file_end: &'static str,
    /// Variable holding the offset of the active rule's table inside
    /// `transitions`.
    pub transition_base: &'static str,
    /// Variable holding the DFA start state.
    pub initial_state: &'static str,
    /// Variable holding the number of DFA states.
    pub state_count: &'static str,
    /// Number of lanes composed together; expected to be a power of two.
    pub subgroup_width: u32,
}
37
38impl Default for ParallelDfaBindings {
39    fn default() -> Self {
40        Self {
41            transitions: "transitions",
42            haystack: "haystack",
43            lane_prefix: "lane_prefix",
44            lane_next: "lane_next",
45            out_state_by_lane: "out_state_by_lane",
46            file_start: "file_start",
47            file_end: "file_end",
48            transition_base: "transition_base",
49            initial_state: "initial_state",
50            state_count: "state_count",
51            subgroup_width: DEFAULT_SUBGROUP_WIDTH,
52        }
53    }
54}
55
56/// Generate a subgroup prefix-composition DFA fragment using default names.
57///
58/// The caller supplies `lane_prefix` and `lane_next` scratch buffers sized at
59/// `subgroup_width * state_count` `u32` entries.
60#[must_use]
61pub fn dfa_byte_scanner_parallel_composition() -> Vec<Node> {
62    dfa_byte_scanner_parallel_composition_with(&ParallelDfaBindings::default())
63}
64
65/// Generate a concrete subgroup prefix-composition DFA fragment.
66///
67/// Every lane first builds the transition function for its byte, storing an
68/// identity function for inactive lanes past `file_end`. The fixed doubling
69/// stages then compose transition functions with subgroup shuffles and
70/// workgroup barriers. The final per-lane state is written to
71/// `out_state_by_lane[local_x]`.
72#[must_use]
73pub fn dfa_byte_scanner_parallel_composition_with(bindings: &ParallelDfaBindings) -> Vec<Node> {
74    let mut nodes = vec![
75        Node::let_bind("lane_id", Expr::invocation_local_x()),
76        Node::let_bind(
77            "lane_byte_pos",
78            Expr::add(Expr::var(bindings.file_start), Expr::var("lane_id")),
79        ),
80        Node::let_bind(
81            "lane_active",
82            Expr::lt(Expr::var("lane_byte_pos"), Expr::var(bindings.file_end)),
83        ),
84        Node::let_bind(
85            "lane_byte",
86            Expr::select(
87                Expr::var("lane_active"),
88                Expr::cast(
89                    vyre_foundation::ir::DataType::U32,
90                    Expr::load(bindings.haystack, Expr::var("lane_byte_pos")),
91                ),
92                Expr::u32(0),
93            ),
94        ),
95        Node::loop_for(
96            "state",
97            Expr::u32(0),
98            Expr::var(bindings.state_count),
99            vec![Node::store(
100                bindings.lane_prefix,
101                table_index("lane_id", bindings.state_count, Expr::var("state")),
102                Expr::select(
103                    Expr::var("lane_active"),
104                    Expr::load(
105                        bindings.transitions,
106                        Expr::add(
107                            Expr::var(bindings.transition_base),
108                            Expr::add(
109                                Expr::mul(Expr::var("state"), Expr::u32(ALPHABET_SIZE)),
110                                Expr::var("lane_byte"),
111                            ),
112                        ),
113                    ),
114                    Expr::var("state"),
115                ),
116            )],
117        ),
118        Node::barrier(),
119    ];
120
121    let mut stride = 1;
122    while stride < bindings.subgroup_width {
123        append_prefix_stage(&mut nodes, bindings, stride);
124        stride *= 2;
125    }
126
127    nodes.extend([
128        Node::store(
129            bindings.out_state_by_lane,
130            Expr::var("lane_id"),
131            Expr::load(
132                bindings.lane_prefix,
133                table_index(
134                    "lane_id",
135                    bindings.state_count,
136                    Expr::var(bindings.initial_state),
137                ),
138            ),
139        ),
140        Node::barrier(),
141    ]);
142    nodes
143}
144
145fn append_prefix_stage(nodes: &mut Vec<Node>, bindings: &ParallelDfaBindings, stride: u32) {
146    nodes.push(Node::loop_for(
147        "state",
148        Expr::u32(0),
149        Expr::var(bindings.state_count),
150        vec![
151            Node::let_bind(
152                "source_lane",
153                Expr::select(
154                    Expr::ge(Expr::var("lane_id"), Expr::u32(stride)),
155                    Expr::sub(Expr::var("lane_id"), Expr::u32(stride)),
156                    Expr::u32(0),
157                ),
158            ),
159            Node::let_bind(
160                "previous_state",
161                Expr::subgroup_shuffle(
162                    Expr::load(
163                        bindings.lane_prefix,
164                        table_index("lane_id", bindings.state_count, Expr::var("state")),
165                    ),
166                    Expr::var("source_lane"),
167                ),
168            ),
169            Node::let_bind(
170                "composed_state",
171                Expr::select(
172                    Expr::ge(Expr::var("lane_id"), Expr::u32(stride)),
173                    Expr::load(
174                        bindings.lane_prefix,
175                        table_index("lane_id", bindings.state_count, Expr::var("previous_state")),
176                    ),
177                    Expr::load(
178                        bindings.lane_prefix,
179                        table_index("lane_id", bindings.state_count, Expr::var("state")),
180                    ),
181                ),
182            ),
183            Node::store(
184                bindings.lane_next,
185                table_index("lane_id", bindings.state_count, Expr::var("state")),
186                Expr::var("composed_state"),
187            ),
188        ],
189    ));
190    nodes.push(Node::barrier());
191    nodes.push(Node::loop_for(
192        "state",
193        Expr::u32(0),
194        Expr::var(bindings.state_count),
195        vec![Node::store(
196            bindings.lane_prefix,
197            table_index("lane_id", bindings.state_count, Expr::var("state")),
198            Expr::load(
199                bindings.lane_next,
200                table_index("lane_id", bindings.state_count, Expr::var("state")),
201            ),
202        )],
203    ));
204    nodes.push(Node::barrier());
205}
206
207fn table_index(lane_var: &str, state_count_var: &str, state: Expr) -> Expr {
208    Expr::add(
209        Expr::mul(Expr::var(lane_var), Expr::var(state_count_var)),
210        state,
211    )
212}
213
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parallel_dfa_fragment_has_prefix_barriers_and_output() {
        let nodes = dfa_byte_scanner_parallel_composition();

        let seq_cst_barriers = nodes
            .iter()
            .filter(|node| {
                matches!(
                    node,
                    Node::Barrier {
                        ordering: vyre::memory_model::MemoryOrdering::SeqCst
                    }
                )
            })
            .count();
        assert!(
            seq_cst_barriers >= 2,
            "prefix composition must synchronize scratch-table stages"
        );

        let publishes_output = stores_buffer(&nodes, "out_state_by_lane");
        assert!(publishes_output, "fragment must publish per-lane states");
    }

    /// Recursively search `nodes` for a `Store` into `name`. The search
    /// descends into blocks, loops, regions and both arms of an `If`.
    fn stores_buffer(nodes: &[Node], name: &str) -> bool {
        for node in nodes {
            let found = match node {
                Node::Store { buffer, .. } => buffer.as_str() == name,
                Node::Block(body) | Node::Loop { body, .. } => stores_buffer(body, name),
                Node::Region { body, .. } => stores_buffer(body, name),
                Node::If {
                    then, otherwise, ..
                } => stores_buffer(then, name) || stores_buffer(otherwise, name),
                _ => false,
            };
            if found {
                return true;
            }
        }
        false
    }
}