1mod buffers;
11mod input;
12mod readback;
13mod resources;
14mod transitions;
15
16use self::buffers::{entry, storage_buffer};
17use self::input::{checked_input_len, dispatch_groups, DfaParams};
18use self::readback::{read_matches, read_one_u32};
19use self::resources::ScanResources;
20use self::transitions::{build_accept_map, compile_pipeline, validate_compile_inputs};
21use std::sync::Mutex;
22use vyre::error::{Error, Result};
23
/// Count of distinct input byte values: one for every possible `u8`.
pub(crate) const BYTE_CLASSES: usize = 1 << 8;
/// All-ones `u32` sentinel; by its name, it marks a state that accepts no pattern.
pub(crate) const SENTINEL_NO_ACCEPT: u32 = u32::MAX;
/// Match capacity used when a DFA is built through `GpuDfa::compile`.
pub const DEFAULT_MAX_MATCHES: u32 = 1 << 16;

/// Largest match capacity a DFA may request (presumably enforced during compile validation).
pub const MAX_DFA_MATCHES: u32 = 1_000_000;
/// A DFA compiled for GPU scanning, together with the device-resident tables it scans with.
///
/// Created by `GpuDfa::compile` / `GpuDfa::compile_with_max_matches` and driven by
/// `GpuDfa::scan` / `GpuDfa::scan_shared`.
// Field declaration order is also the drop order of the GPU handles below.
#[non_exhaustive]
pub struct GpuDfa {
    // Device every pipeline/buffer was created on; `scan` rejects any other device.
    device: wgpu::Device,
    // Whether `device` was the process-wide cached device at compile time; gates `scan_shared`.
    compiled_with_cached_device: bool,
    // Compute pipeline that executes the DFA scan pass.
    pipeline: wgpu::ComputePipeline,
    // Storage buffer holding the caller-supplied transition table.
    transition_buffer: wgpu::Buffer,
    // Storage buffer holding the accept map produced by `build_accept_map`.
    accept_buffer: wgpu::Buffer,
    // Storage buffer holding `pattern_lengths`.
    pattern_length_buffer: wgpu::Buffer,
    // Host copy of the pattern lengths, exposed through `pattern_lengths()`.
    pattern_lengths: Vec<u32>,
    // Number of DFA states, forwarded to the shader in `DfaParams`.
    state_count: u32,
    // Upper bound on matches read back per scan; reported counts are clamped to this.
    max_matches: u32,
    // Pool of reusable per-scan buffers; the mutex lets `&self` scans share it.
    resources: Mutex<Vec<ScanResources>>,
}
46
47impl std::fmt::Debug for GpuDfa {
48 fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49 formatter
50 .debug_struct("GpuDfa")
51 .field("state_count", &self.state_count)
52 .field("pattern_count", &self.pattern_lengths.len())
53 .field("max_matches", &self.max_matches)
54 .finish_non_exhaustive()
55 }
56}
57
58fn zero_dfa_input_padding(
59 queue: &wgpu::Queue,
60 buffer: &wgpu::Buffer,
61 written: usize,
62) -> Result<()> {
63 let padding_len = written.next_multiple_of(4) - written;
64 if padding_len == 0 {
65 return Ok(());
66 }
67 let offset = u64::try_from(written).map_err(|source| Error::Dfa {
68 message: format!(
69 "DFA input byte count {written} cannot fit u64: {source}. Fix: scan chunks smaller than 4 GiB."
70 ),
71 })?;
72 let padding = [0u8; 4];
73 queue.write_buffer(buffer, offset, &padding[..padding_len]);
74 Ok(())
75}
76
77impl GpuDfa {
78 pub fn compile(
84 device: &wgpu::Device,
85 transitions: &[u32],
86 state_count: usize,
87 accept_states: &[(u32, u32)],
88 output_links: &[u32],
89 pattern_lengths: &[u32],
90 ) -> Result<Self> {
91 Self::compile_with_max_matches(
92 device,
93 transitions,
94 state_count,
95 accept_states,
96 output_links,
97 pattern_lengths,
98 DEFAULT_MAX_MATCHES,
99 )
100 }
101
102 pub fn compile_with_max_matches(
108 device: &wgpu::Device,
109 transitions: &[u32],
110 state_count: usize,
111 accept_states: &[(u32, u32)],
112 output_links: &[u32],
113 pattern_lengths: &[u32],
114 max_matches: u32,
115 ) -> Result<Self> {
116 validate_compile_inputs(
117 device,
118 transitions,
119 state_count,
120 accept_states,
121 output_links,
122 pattern_lengths,
123 max_matches,
124 )?;
125
126 let accept_map = build_accept_map(state_count, accept_states)?;
127 let pipeline = compile_pipeline(device)?;
128 let transition_buffer = storage_buffer(device, "vyre dfa transitions", transitions);
129 let accept_buffer = storage_buffer(device, "vyre dfa accept map", &accept_map);
130 let pattern_length_buffer =
131 storage_buffer(device, "vyre dfa pattern lengths", pattern_lengths);
132
133 Ok(Self {
134 device: device.clone(),
135 compiled_with_cached_device: crate::runtime::device::is_cached_device(device),
136 pipeline,
137 transition_buffer,
138 accept_buffer,
139 pattern_length_buffer,
140 pattern_lengths: pattern_lengths.to_vec(),
141 state_count: u32::try_from(state_count).map_err(|source| Error::Dfa {
142 message: format!(
143 "DFA state_count {state_count} cannot fit u32: {source}. Fix: split the automaton or reduce states."
144 ),
145 })?,
146 max_matches,
147 resources: Mutex::new(Vec::new()),
148 })
149 }
150
151 pub fn scan(
156 &self,
157 device: &wgpu::Device,
158 queue: &wgpu::Queue,
159 input: &[u8],
160 command_encoder: Option<&mut wgpu::CommandEncoder>,
161 ) -> Result<Vec<vyre::Match>> {
162 if *device != self.device {
163 return Err(Error::Dfa {
164 message: "DFA scan device differs from compile device. Fix: scan with the same wgpu::Device and matching Queue used to compile the DFA.".to_string(),
165 });
166 }
167 if input.is_empty() {
168 return Ok(Vec::new());
169 }
170 let input_len = checked_input_len(input)?;
171 let mut resources = self.acquire_scan_resources(input_len)?;
172 let result =
173 self.scan_with_resources(queue, input, input_len, &mut resources, command_encoder);
174 self.release_scan_resources(resources)?;
175 result
176 }
177
178 fn acquire_scan_resources(&self, input_len: u32) -> Result<ScanResources> {
179 let mut pool = self.resources.lock().map_err(|source| Error::Dfa {
180 message: format!("DFA scan resources mutex is poisoned: {source}. Fix: recreate the compiled DFA and inspect panics from concurrent scan tasks."),
181 })?;
182 if let Some(index) = pool
183 .iter()
184 .position(|resources| resources.max_input_len >= input_len)
185 {
186 return Ok(pool.swap_remove(index));
187 }
188 drop(pool);
189 ScanResources::new(&self.device, input_len, self.max_matches)
190 }
191
192 fn release_scan_resources(&self, resources: ScanResources) -> Result<()> {
193 let mut pool = self.resources.lock().map_err(|source| Error::Dfa {
194 message: format!("DFA scan resources mutex is poisoned while releasing resources: {source}. Fix: recreate the compiled DFA and inspect panics from concurrent scan tasks."),
195 })?;
196 pool.push(resources);
197 Ok(())
198 }
199
200 pub(crate) fn scan_with_resources(
201 &self,
202 queue: &wgpu::Queue,
203 input: &[u8],
204 input_len: u32,
205 resources: &mut ScanResources,
206 command_encoder: Option<&mut wgpu::CommandEncoder>,
207 ) -> Result<Vec<vyre::Match>> {
208 if input_len > resources.max_input_len {
209 return Err(Error::Dfa {
210 message: format!(
211 "DFA input length {input_len} exceeds ScanResources capacity {}. Fix: create larger ScanResources.",
212 resources.max_input_len
213 ),
214 });
215 }
216
217 let params = DfaParams {
218 input_len,
219 state_count: self.state_count,
220 max_matches: self.max_matches,
221 _pad: 0,
222 };
223 queue.write_buffer(&resources.input_buffer, 0, input);
224 zero_dfa_input_padding(queue, &resources.input_buffer, input.len())?;
225 queue.write_buffer(&resources.params_buffer, 0, bytemuck::bytes_of(¶ms));
226
227 let bind_group = self.create_bind_group(
228 &resources.input_buffer,
229 &resources.match_buffer,
230 &resources.match_count_buffer,
231 &resources.params_buffer,
232 );
233 let mut owned_encoder = command_encoder.is_none().then(|| {
234 self.device
235 .create_command_encoder(&wgpu::CommandEncoderDescriptor {
236 label: Some("vyre dfa dispatch and readback"),
237 })
238 });
239 let encoder = if let Some(encoder) = command_encoder {
240 encoder
241 } else {
242 owned_encoder
243 .as_mut()
244 .expect("owned encoder must be present when command_encoder is omitted")
245 };
246 encoder.clear_buffer(&resources.match_count_buffer, 0, None);
247 {
248 let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
249 label: Some("vyre dfa scan pass"),
250 timestamp_writes: None,
251 });
252 pass.set_pipeline(&self.pipeline);
253 pass.set_bind_group(0, &bind_group, &[]);
254 pass.dispatch_workgroups(dispatch_groups(input_len), 1, 1);
255 }
256 encoder.copy_buffer_to_buffer(
257 &resources.match_count_buffer,
258 0,
259 &resources.count_readback,
260 0,
261 4,
262 );
263 let Some(owned_encoder) = owned_encoder else {
264 return Err(Error::Dfa {
265 message: "DFA scan was called with an external command encoder, but this API returns readback matches that are unavailable until the caller submits that encoder. Fix: call with `None` for immediate submit/readback, or add a deferred DFA API that returns readback buffers.".to_string(),
266 });
267 };
268 let count_submission = queue.submit(std::iter::once(owned_encoder.finish()));
269
270 let reported = read_one_u32(
271 &self.device,
272 &resources.count_readback,
273 "match count",
274 count_submission,
275 )?;
276 let captured = reported.min(self.max_matches);
277 if captured == 0 {
278 return Ok(Vec::new());
279 }
280 let mut readback_encoder =
281 self.device
282 .create_command_encoder(&wgpu::CommandEncoderDescriptor {
283 label: Some("vyre dfa match readback"),
284 });
285 readback_encoder.copy_buffer_to_buffer(
286 &resources.match_buffer,
287 0,
288 &resources.match_readback,
289 0,
290 buffers::match_buffer_size(captured)?,
291 );
292 let match_submission = queue.submit(std::iter::once(readback_encoder.finish()));
293 let mut matches = read_matches(
294 &self.device,
295 &resources.match_readback,
296 captured,
297 match_submission,
298 )?;
299 matches.sort_unstable();
300 Ok(matches)
301 }
302
303 pub fn scan_shared(&self, input: &[u8]) -> Result<Vec<vyre::Match>> {
308 if !self.compiled_with_cached_device {
309 return Err(Error::Dfa {
310 message: "DFA was compiled with a non-shared GPU device. Fix: compile with vyre::runtime::cached_device() before calling scan_shared(), or call scan() with the original device and queue.".to_string(),
311 });
312 }
313 let (device, queue) = crate::runtime::cached_device()?;
314 self.scan(device, queue, input, None)
315 }
316
317 #[must_use]
319 pub fn state_count(&self) -> u32 {
320 self.state_count
321 }
322
323 #[must_use]
325 pub fn max_matches(&self) -> u32 {
326 self.max_matches
327 }
328
329 #[must_use]
331 pub fn pattern_lengths(&self) -> &[u32] {
332 &self.pattern_lengths
333 }
334
335 pub(crate) fn create_bind_group(
336 &self,
337 input_buffer: &wgpu::Buffer,
338 match_buffer: &wgpu::Buffer,
339 match_count_buffer: &wgpu::Buffer,
340 params_buffer: &wgpu::Buffer,
341 ) -> wgpu::BindGroup {
342 let layout = self.pipeline.get_bind_group_layout(0);
343 self.device.create_bind_group(&wgpu::BindGroupDescriptor {
344 label: Some("vyre dfa bind group"),
345 layout: &layout,
346 entries: &[
347 entry(0, input_buffer),
348 entry(1, &self.transition_buffer),
349 entry(2, &self.accept_buffer),
350 entry(3, match_buffer),
351 entry(4, match_count_buffer),
352 entry(5, params_buffer),
353 entry(6, &self.pattern_length_buffer),
354 ],
355 })
356 }
357}