vyre_runtime/megakernel/advanced/
parallel_dfa.rs1use vyre_foundation::ir::{Expr, Node};
7
/// Lane count used by [`ParallelDfaBindings::default`]; it bounds the prefix
/// stages (strides `1, 2, 4, …` strictly below this width).
const DEFAULT_SUBGROUP_WIDTH: u32 = 32;
/// Row stride of the transition table: state `s` owns the `ALPHABET_SIZE`
/// consecutive entries starting at `s * ALPHABET_SIZE`, indexed by the byte
/// value read from the haystack (2^8 = 256 entries).
const ALPHABET_SIZE: u32 = 1 << u8::BITS;
10
/// Names of the buffers and IR variables used by the parallel DFA fragment,
/// plus the number of lanes that take part in the prefix composition.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ParallelDfaBindings {
    /// Buffer holding the transition table, read at
    /// `transition_base + state * ALPHABET_SIZE + byte`.
    pub transitions: &'static str,
    /// Buffer the input bytes are loaded from, at `file_start + lane_id`.
    pub haystack: &'static str,
    /// Per-lane scratch table of `state_count` entries per lane, indexed
    /// `lane_id * state_count + state`; holds the lane's accumulated
    /// transition map.
    pub lane_prefix: &'static str,
    /// Scratch table with the same layout as `lane_prefix`, written by each
    /// composition stage before being copied back.
    pub lane_next: &'static str,
    /// Output buffer indexed by `lane_id`; receives the lane's
    /// `lane_prefix` entry for `initial_state`.
    pub out_state_by_lane: &'static str,
    /// Variable holding the haystack position consumed by lane 0.
    pub file_start: &'static str,
    /// Variable holding the exclusive end position; lanes whose position is
    /// not below it are inactive.
    pub file_end: &'static str,
    /// Variable holding the offset added to every `transitions` lookup.
    pub transition_base: &'static str,
    /// Variable holding the state whose image is published per lane.
    pub initial_state: &'static str,
    /// Variable holding the number of DFA states (scratch-table row width and
    /// per-state loop bound).
    pub state_count: &'static str,
    /// Lane count; prefix stages are emitted for strides `1, 2, 4, …` below it.
    pub subgroup_width: u32,
}
37
38impl Default for ParallelDfaBindings {
39 fn default() -> Self {
40 Self {
41 transitions: "transitions",
42 haystack: "haystack",
43 lane_prefix: "lane_prefix",
44 lane_next: "lane_next",
45 out_state_by_lane: "out_state_by_lane",
46 file_start: "file_start",
47 file_end: "file_end",
48 transition_base: "transition_base",
49 initial_state: "initial_state",
50 state_count: "state_count",
51 subgroup_width: DEFAULT_SUBGROUP_WIDTH,
52 }
53 }
54}
55
56#[must_use]
61pub fn dfa_byte_scanner_parallel_composition() -> Vec<Node> {
62 dfa_byte_scanner_parallel_composition_with(&ParallelDfaBindings::default())
63}
64
65#[must_use]
73pub fn dfa_byte_scanner_parallel_composition_with(bindings: &ParallelDfaBindings) -> Vec<Node> {
74 let mut nodes = vec![
75 Node::let_bind("lane_id", Expr::invocation_local_x()),
76 Node::let_bind(
77 "lane_byte_pos",
78 Expr::add(Expr::var(bindings.file_start), Expr::var("lane_id")),
79 ),
80 Node::let_bind(
81 "lane_active",
82 Expr::lt(Expr::var("lane_byte_pos"), Expr::var(bindings.file_end)),
83 ),
84 Node::let_bind(
85 "lane_byte",
86 Expr::select(
87 Expr::var("lane_active"),
88 Expr::cast(
89 vyre_foundation::ir::DataType::U32,
90 Expr::load(bindings.haystack, Expr::var("lane_byte_pos")),
91 ),
92 Expr::u32(0),
93 ),
94 ),
95 Node::loop_for(
96 "state",
97 Expr::u32(0),
98 Expr::var(bindings.state_count),
99 vec![Node::store(
100 bindings.lane_prefix,
101 table_index("lane_id", bindings.state_count, Expr::var("state")),
102 Expr::select(
103 Expr::var("lane_active"),
104 Expr::load(
105 bindings.transitions,
106 Expr::add(
107 Expr::var(bindings.transition_base),
108 Expr::add(
109 Expr::mul(Expr::var("state"), Expr::u32(ALPHABET_SIZE)),
110 Expr::var("lane_byte"),
111 ),
112 ),
113 ),
114 Expr::var("state"),
115 ),
116 )],
117 ),
118 Node::barrier(),
119 ];
120
121 let mut stride = 1;
122 while stride < bindings.subgroup_width {
123 append_prefix_stage(&mut nodes, bindings, stride);
124 stride *= 2;
125 }
126
127 nodes.extend([
128 Node::store(
129 bindings.out_state_by_lane,
130 Expr::var("lane_id"),
131 Expr::load(
132 bindings.lane_prefix,
133 table_index(
134 "lane_id",
135 bindings.state_count,
136 Expr::var(bindings.initial_state),
137 ),
138 ),
139 ),
140 Node::barrier(),
141 ]);
142 nodes
143}
144
145fn append_prefix_stage(nodes: &mut Vec<Node>, bindings: &ParallelDfaBindings, stride: u32) {
146 nodes.push(Node::loop_for(
147 "state",
148 Expr::u32(0),
149 Expr::var(bindings.state_count),
150 vec![
151 Node::let_bind(
152 "source_lane",
153 Expr::select(
154 Expr::ge(Expr::var("lane_id"), Expr::u32(stride)),
155 Expr::sub(Expr::var("lane_id"), Expr::u32(stride)),
156 Expr::u32(0),
157 ),
158 ),
159 Node::let_bind(
160 "previous_state",
161 Expr::subgroup_shuffle(
162 Expr::load(
163 bindings.lane_prefix,
164 table_index("lane_id", bindings.state_count, Expr::var("state")),
165 ),
166 Expr::var("source_lane"),
167 ),
168 ),
169 Node::let_bind(
170 "composed_state",
171 Expr::select(
172 Expr::ge(Expr::var("lane_id"), Expr::u32(stride)),
173 Expr::load(
174 bindings.lane_prefix,
175 table_index("lane_id", bindings.state_count, Expr::var("previous_state")),
176 ),
177 Expr::load(
178 bindings.lane_prefix,
179 table_index("lane_id", bindings.state_count, Expr::var("state")),
180 ),
181 ),
182 ),
183 Node::store(
184 bindings.lane_next,
185 table_index("lane_id", bindings.state_count, Expr::var("state")),
186 Expr::var("composed_state"),
187 ),
188 ],
189 ));
190 nodes.push(Node::barrier());
191 nodes.push(Node::loop_for(
192 "state",
193 Expr::u32(0),
194 Expr::var(bindings.state_count),
195 vec![Node::store(
196 bindings.lane_prefix,
197 table_index("lane_id", bindings.state_count, Expr::var("state")),
198 Expr::load(
199 bindings.lane_next,
200 table_index("lane_id", bindings.state_count, Expr::var("state")),
201 ),
202 )],
203 ));
204 nodes.push(Node::barrier());
205}
206
207fn table_index(lane_var: &str, state_count_var: &str, state: Expr) -> Expr {
208 Expr::add(
209 Expr::mul(Expr::var(lane_var), Expr::var(state_count_var)),
210 state,
211 )
212}
213
#[cfg(test)]
mod tests {
    use super::*;

    /// Reports whether any node in `nodes`, searching nested block, loop,
    /// region and branch bodies, stores into the buffer called `name`.
    fn stores_buffer(nodes: &[Node], name: &str) -> bool {
        for node in nodes {
            let found = match node {
                Node::Store { buffer, .. } => buffer.as_str() == name,
                Node::Block(body) | Node::Loop { body, .. } => stores_buffer(body, name),
                Node::Region { body, .. } => stores_buffer(body, name),
                Node::If {
                    then, otherwise, ..
                } => stores_buffer(then, name) || stores_buffer(otherwise, name),
                _ => false,
            };
            if found {
                return true;
            }
        }
        false
    }

    #[test]
    fn parallel_dfa_fragment_has_prefix_barriers_and_output() {
        let nodes = dfa_byte_scanner_parallel_composition();

        let seq_cst_barriers = nodes
            .iter()
            .filter(|candidate| {
                matches!(
                    candidate,
                    Node::Barrier {
                        ordering: vyre::memory_model::MemoryOrdering::SeqCst
                    }
                )
            })
            .count();
        assert!(
            seq_cst_barriers >= 2,
            "prefix composition must synchronize scratch-table stages"
        );

        assert!(
            stores_buffer(&nodes, "out_state_by_lane"),
            "fragment must publish per-lane states"
        );
    }
}