Skip to main content

vyre_wgpu/engine/dfa/
mod.rs

1//! Host-side DFA workflow dispatcher.
2//!
3//! NOTE: This is NOT an IR op domain. It accepts runtime DFA tables and input
4//! bytes, compiles and owns GPU pipelines and buffers, dispatches scanning
5//! kernels, performs deterministic readback sorting, and returns
6//! `vyre::Match` values. The IR-side match domain lives under
7//! `vyre::ops::match_ops`; those modules produce `Program` values that go
8//! through validate and lower.
9
10mod buffers;
11mod input;
12mod readback;
13mod resources;
14mod transitions;
15
16use self::buffers::{entry, storage_buffer};
17use self::input::{checked_input_len, dispatch_groups, DfaParams};
18use self::readback::{read_matches, read_one_u32};
19use self::resources::ScanResources;
20use self::transitions::{build_accept_map, compile_pipeline, validate_compile_inputs};
21use std::sync::Mutex;
22use vyre::error::{Error, Result};
23
/// Count of distinct input byte values (every possible `u8`).
pub(crate) const BYTE_CLASSES: usize = 1 << 8;
/// Reserved `u32` value meaning "no pattern accepted here".
///
/// NOTE(review): presumably written into the accept map by
/// `transitions::build_accept_map` for non-accepting states — confirm there.
pub(crate) const SENTINEL_NO_ACCEPT: u32 = u32::MAX;
/// Default maximum matches captured by one GPU DFA scan.
pub const DEFAULT_MAX_MATCHES: u32 = 1 << 16;

/// Upper bound on the matches a single GPU DFA scan may allocate and read back.
pub const MAX_DFA_MATCHES: u32 = 1_000 * 1_000;
31
/// A GPU-compiled DFA scanner.
///
/// Holds the compute pipeline and the DFA tables uploaded at compile time.
/// It also keeps a mutex-guarded pool of per-scan buffers that `scan` takes
/// from and returns to.
#[non_exhaustive]
pub struct GpuDfa {
    /// Device the DFA was compiled on; `scan` rejects any other device.
    device: wgpu::Device,
    /// Whether `device` was the shared runtime device when compiled; gates `scan_shared`.
    compiled_with_cached_device: bool,
    /// Scan compute pipeline produced by `compile_pipeline`.
    pipeline: wgpu::ComputePipeline,
    /// Storage buffer holding the caller-supplied transition table (binding 1).
    transition_buffer: wgpu::Buffer,
    /// Storage buffer holding the accept map from `build_accept_map` (binding 2).
    accept_buffer: wgpu::Buffer,
    /// Storage buffer holding the pattern lengths (binding 6).
    pattern_length_buffer: wgpu::Buffer,
    /// Host-side copy of the pattern lengths, returned by `pattern_lengths()`.
    pattern_lengths: Vec<u32>,
    /// Number of DFA states, passed to the kernel through `DfaParams`.
    state_count: u32,
    /// Cap on matches captured and read back per scan; larger GPU counts are truncated.
    max_matches: u32,
    /// Pool of reusable per-scan buffers (input, params, matches, readback staging).
    resources: Mutex<Vec<ScanResources>>,
}
46
47impl std::fmt::Debug for GpuDfa {
48    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49        formatter
50            .debug_struct("GpuDfa")
51            .field("state_count", &self.state_count)
52            .field("pattern_count", &self.pattern_lengths.len())
53            .field("max_matches", &self.max_matches)
54            .finish_non_exhaustive()
55    }
56}
57
58fn zero_dfa_input_padding(
59    queue: &wgpu::Queue,
60    buffer: &wgpu::Buffer,
61    written: usize,
62) -> Result<()> {
63    let padding_len = written.next_multiple_of(4) - written;
64    if padding_len == 0 {
65        return Ok(());
66    }
67    let offset = u64::try_from(written).map_err(|source| Error::Dfa {
68        message: format!(
69            "DFA input byte count {written} cannot fit u64: {source}. Fix: scan chunks smaller than 4 GiB."
70        ),
71    })?;
72    let padding = [0u8; 4];
73    queue.write_buffer(buffer, offset, &padding[..padding_len]);
74    Ok(())
75}
76
77impl GpuDfa {
78    /// Internal pipeline helper. User-facing DFA construction lives in
79    /// `std::pattern::aho_corasick_build`.
80    ///
81    /// # Errors
82    /// Returns [`Error::Dfa`] when the table, output links, or bindings are invalid.
83    pub fn compile(
84        device: &wgpu::Device,
85        transitions: &[u32],
86        state_count: usize,
87        accept_states: &[(u32, u32)],
88        output_links: &[u32],
89        pattern_lengths: &[u32],
90    ) -> Result<Self> {
91        Self::compile_with_max_matches(
92            device,
93            transitions,
94            state_count,
95            accept_states,
96            output_links,
97            pattern_lengths,
98            DEFAULT_MAX_MATCHES,
99        )
100    }
101
102    /// Internal pipeline helper. User-facing DFA construction lives in
103    /// `std::pattern::aho_corasick_build`.
104    ///
105    /// # Errors
106    /// Returns [`Error::Dfa`] when validation of the DFA or GPU binding sizes fails.
107    pub fn compile_with_max_matches(
108        device: &wgpu::Device,
109        transitions: &[u32],
110        state_count: usize,
111        accept_states: &[(u32, u32)],
112        output_links: &[u32],
113        pattern_lengths: &[u32],
114        max_matches: u32,
115    ) -> Result<Self> {
116        validate_compile_inputs(
117            device,
118            transitions,
119            state_count,
120            accept_states,
121            output_links,
122            pattern_lengths,
123            max_matches,
124        )?;
125
126        let accept_map = build_accept_map(state_count, accept_states)?;
127        let pipeline = compile_pipeline(device)?;
128        let transition_buffer = storage_buffer(device, "vyre dfa transitions", transitions);
129        let accept_buffer = storage_buffer(device, "vyre dfa accept map", &accept_map);
130        let pattern_length_buffer =
131            storage_buffer(device, "vyre dfa pattern lengths", pattern_lengths);
132
133        Ok(Self {
134            device: device.clone(),
135            compiled_with_cached_device: crate::runtime::device::is_cached_device(device),
136            pipeline,
137            transition_buffer,
138            accept_buffer,
139            pattern_length_buffer,
140            pattern_lengths: pattern_lengths.to_vec(),
141            state_count: u32::try_from(state_count).map_err(|source| Error::Dfa {
142                message: format!(
143                    "DFA state_count {state_count} cannot fit u32: {source}. Fix: split the automaton or reduce states."
144                ),
145            })?,
146            max_matches,
147            resources: Mutex::new(Vec::new()),
148        })
149    }
150
151    /// Scan input bytes on the GPU and return captured matches.
152    ///
153    /// # Errors
154    /// Returns [`Error::Dfa`] if the input, device, queue, or readback is invalid.
155    pub fn scan(
156        &self,
157        device: &wgpu::Device,
158        queue: &wgpu::Queue,
159        input: &[u8],
160        command_encoder: Option<&mut wgpu::CommandEncoder>,
161    ) -> Result<Vec<vyre::Match>> {
162        if *device != self.device {
163            return Err(Error::Dfa {
164                message: "DFA scan device differs from compile device. Fix: scan with the same wgpu::Device and matching Queue used to compile the DFA.".to_string(),
165            });
166        }
167        if input.is_empty() {
168            return Ok(Vec::new());
169        }
170        let input_len = checked_input_len(input)?;
171        let mut resources = self.acquire_scan_resources(input_len)?;
172        let result =
173            self.scan_with_resources(queue, input, input_len, &mut resources, command_encoder);
174        self.release_scan_resources(resources)?;
175        result
176    }
177
178    fn acquire_scan_resources(&self, input_len: u32) -> Result<ScanResources> {
179        let mut pool = self.resources.lock().map_err(|source| Error::Dfa {
180            message: format!("DFA scan resources mutex is poisoned: {source}. Fix: recreate the compiled DFA and inspect panics from concurrent scan tasks."),
181        })?;
182        if let Some(index) = pool
183            .iter()
184            .position(|resources| resources.max_input_len >= input_len)
185        {
186            return Ok(pool.swap_remove(index));
187        }
188        drop(pool);
189        ScanResources::new(&self.device, input_len, self.max_matches)
190    }
191
192    fn release_scan_resources(&self, resources: ScanResources) -> Result<()> {
193        let mut pool = self.resources.lock().map_err(|source| Error::Dfa {
194            message: format!("DFA scan resources mutex is poisoned while releasing resources: {source}. Fix: recreate the compiled DFA and inspect panics from concurrent scan tasks."),
195        })?;
196        pool.push(resources);
197        Ok(())
198    }
199
200    pub(crate) fn scan_with_resources(
201        &self,
202        queue: &wgpu::Queue,
203        input: &[u8],
204        input_len: u32,
205        resources: &mut ScanResources,
206        command_encoder: Option<&mut wgpu::CommandEncoder>,
207    ) -> Result<Vec<vyre::Match>> {
208        if input_len > resources.max_input_len {
209            return Err(Error::Dfa {
210                message: format!(
211                    "DFA input length {input_len} exceeds ScanResources capacity {}. Fix: create larger ScanResources.",
212                    resources.max_input_len
213                ),
214            });
215        }
216
217        let params = DfaParams {
218            input_len,
219            state_count: self.state_count,
220            max_matches: self.max_matches,
221            _pad: 0,
222        };
223        queue.write_buffer(&resources.input_buffer, 0, input);
224        zero_dfa_input_padding(queue, &resources.input_buffer, input.len())?;
225        queue.write_buffer(&resources.params_buffer, 0, bytemuck::bytes_of(&params));
226
227        let bind_group = self.create_bind_group(
228            &resources.input_buffer,
229            &resources.match_buffer,
230            &resources.match_count_buffer,
231            &resources.params_buffer,
232        );
233        let mut owned_encoder = command_encoder.is_none().then(|| {
234            self.device
235                .create_command_encoder(&wgpu::CommandEncoderDescriptor {
236                    label: Some("vyre dfa dispatch and readback"),
237                })
238        });
239        let encoder = if let Some(encoder) = command_encoder {
240            encoder
241        } else {
242            owned_encoder
243                .as_mut()
244                .expect("owned encoder must be present when command_encoder is omitted")
245        };
246        encoder.clear_buffer(&resources.match_count_buffer, 0, None);
247        {
248            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
249                label: Some("vyre dfa scan pass"),
250                timestamp_writes: None,
251            });
252            pass.set_pipeline(&self.pipeline);
253            pass.set_bind_group(0, &bind_group, &[]);
254            pass.dispatch_workgroups(dispatch_groups(input_len), 1, 1);
255        }
256        encoder.copy_buffer_to_buffer(
257            &resources.match_count_buffer,
258            0,
259            &resources.count_readback,
260            0,
261            4,
262        );
263        let Some(owned_encoder) = owned_encoder else {
264            return Err(Error::Dfa {
265                message: "DFA scan was called with an external command encoder, but this API returns readback matches that are unavailable until the caller submits that encoder. Fix: call with `None` for immediate submit/readback, or add a deferred DFA API that returns readback buffers.".to_string(),
266            });
267        };
268        let count_submission = queue.submit(std::iter::once(owned_encoder.finish()));
269
270        let reported = read_one_u32(
271            &self.device,
272            &resources.count_readback,
273            "match count",
274            count_submission,
275        )?;
276        let captured = reported.min(self.max_matches);
277        if captured == 0 {
278            return Ok(Vec::new());
279        }
280        let mut readback_encoder =
281            self.device
282                .create_command_encoder(&wgpu::CommandEncoderDescriptor {
283                    label: Some("vyre dfa match readback"),
284                });
285        readback_encoder.copy_buffer_to_buffer(
286            &resources.match_buffer,
287            0,
288            &resources.match_readback,
289            0,
290            buffers::match_buffer_size(captured)?,
291        );
292        let match_submission = queue.submit(std::iter::once(readback_encoder.finish()));
293        let mut matches = read_matches(
294            &self.device,
295            &resources.match_readback,
296            captured,
297            match_submission,
298        )?;
299        matches.sort_unstable();
300        Ok(matches)
301    }
302
303    /// Scan input bytes using the shared vyre device.
304    ///
305    /// # Errors
306    /// Returns [`Error::Dfa`] if this DFA was not compiled on the shared runtime device.
307    pub fn scan_shared(&self, input: &[u8]) -> Result<Vec<vyre::Match>> {
308        if !self.compiled_with_cached_device {
309            return Err(Error::Dfa {
310                message: "DFA was compiled with a non-shared GPU device. Fix: compile with vyre::runtime::cached_device() before calling scan_shared(), or call scan() with the original device and queue.".to_string(),
311            });
312        }
313        let (device, queue) = crate::runtime::cached_device()?;
314        self.scan(device, queue, input, None)
315    }
316
317    /// Number of DFA states in the compiled scanner.
318    #[must_use]
319    pub fn state_count(&self) -> u32 {
320        self.state_count
321    }
322
323    /// Maximum number of matches this scanner captures.
324    #[must_use]
325    pub fn max_matches(&self) -> u32 {
326        self.max_matches
327    }
328
329    /// Pattern lengths supplied at compile time.
330    #[must_use]
331    pub fn pattern_lengths(&self) -> &[u32] {
332        &self.pattern_lengths
333    }
334
335    pub(crate) fn create_bind_group(
336        &self,
337        input_buffer: &wgpu::Buffer,
338        match_buffer: &wgpu::Buffer,
339        match_count_buffer: &wgpu::Buffer,
340        params_buffer: &wgpu::Buffer,
341    ) -> wgpu::BindGroup {
342        let layout = self.pipeline.get_bind_group_layout(0);
343        self.device.create_bind_group(&wgpu::BindGroupDescriptor {
344            label: Some("vyre dfa bind group"),
345            layout: &layout,
346            entries: &[
347                entry(0, input_buffer),
348                entry(1, &self.transition_buffer),
349                entry(2, &self.accept_buffer),
350                entry(3, match_buffer),
351                entry(4, match_count_buffer),
352                entry(5, params_buffer),
353                entry(6, &self.pattern_length_buffer),
354            ],
355        })
356    }
357}