Skip to main content

vyre_reference/dual_impls/scan/dfa/
reference.rs

1use crate::{dual_impls::common, workgroup::Memory};
2use vyre_primitives::PatternMatchDfa;
3
4const HEADER_LEN: usize = 16;
5const MAGIC: &[u8; 4] = b"VDFA";
6
7impl common::ReferenceEvaluator for PatternMatchDfa {
8    fn evaluate(&self, inputs: &[Memory]) -> Result<Memory, common::EvalError> {
9        let haystack = common::one_input(inputs, "scan_dfa")?;
10        let dfa = ParsedDfa::parse(&self.dfa)?;
11        let mut state = dfa.start;
12        let mut offsets = Vec::new();
13        for (offset, byte) in haystack.iter().copied().enumerate() {
14            state = dfa.step(state, byte)?;
15            if dfa.accepts[state] {
16                offsets.push(u32::try_from(offset).map_err(|_| {
17                    common::EvalError::new(
18                        "primitive `scan_dfa` offset exceeds u32. Fix: split haystacks before 4 GiB.",
19                    )
20                })?);
21            }
22        }
23        Ok(common::write_u32s(offsets))
24    }
25}
26
/// A decoded and partially validated `VDFA` automaton, built by `ParsedDfa::parse`.
struct ParsedDfa {
    /// Flattened transition table: the successor of `state` on input `byte`
    /// lives at index `state * 256 + byte`.
    transitions: Vec<u32>,
    /// `accepts[s]` is `true` when state `s` is accepting; its length equals
    /// the number of states.
    accepts: Vec<bool>,
    /// State the scan begins in (always below the state count).
    start: usize,
}
32
33impl ParsedDfa {
34    fn parse(bytes: &[u8]) -> Result<Self, common::EvalError> {
35        if bytes.len() < HEADER_LEN || &bytes[..4] != MAGIC {
36            return Err(common::EvalError::new(
37                "primitive `scan_dfa` expected VDFA header. Fix: encode magic, state_count, start, and accept_count.",
38            ));
39        }
40        let state_count = read_u32_at(bytes, 4)? as usize;
41        let start = read_u32_at(bytes, 8)? as usize;
42        let accept_count = read_u32_at(bytes, 12)? as usize;
43        if state_count == 0 || start >= state_count {
44            return Err(common::EvalError::new(
45                "primitive `scan_dfa` has invalid state count/start. Fix: provide at least one state and a valid start state.",
46            ));
47        }
48        let accept_bytes = accept_count.checked_mul(4).ok_or_else(|| {
49            common::EvalError::new(
50                "primitive `scan_dfa` accept table size overflow. Fix: bound DFA state metadata.",
51            )
52        })?;
53        let transition_start = HEADER_LEN + accept_bytes;
54        let transition_words = state_count.checked_mul(256).ok_or_else(|| {
55            common::EvalError::new(
56                "primitive `scan_dfa` transition table size overflow. Fix: bound DFA state count.",
57            )
58        })?;
59        let transition_bytes = transition_words.checked_mul(4).ok_or_else(|| {
60            common::EvalError::new(
61                "primitive `scan_dfa` transition byte size overflow. Fix: bound DFA state count.",
62            )
63        })?;
64        if bytes.len() != transition_start + transition_bytes {
65            return Err(common::EvalError::new(format!(
66                "primitive `scan_dfa` byte length mismatch: got {}, expected {}. Fix: encode accept states followed by state_count*256 u32 transitions.",
67                bytes.len(),
68                transition_start + transition_bytes
69            )));
70        }
71        let mut accepts = vec![false; state_count];
72        for accept in 0..accept_count {
73            let state = read_u32_at(bytes, HEADER_LEN + accept * 4)? as usize;
74            if state >= state_count {
75                return Err(common::EvalError::new(
76                    "primitive `scan_dfa` accept state is out of range. Fix: keep accept states below state_count.",
77                ));
78            }
79            accepts[state] = true;
80        }
81        let transitions = common::u32_words(&bytes[transition_start..], "scan_dfa")?;
82        Ok(Self {
83            start,
84            accepts,
85            transitions,
86        })
87    }
88
89    fn step(&self, state: usize, byte: u8) -> Result<usize, common::EvalError> {
90        let offset = state * 256 + usize::from(byte);
91        let next = self.transitions[offset] as usize;
92        if next >= self.accepts.len() {
93            Err(common::EvalError::new(
94                "primitive `scan_dfa` transition targets an out-of-range state. Fix: validate every transition target.",
95            ))
96        } else {
97            Ok(next)
98        }
99    }
100}
101
102fn read_u32_at(bytes: &[u8], offset: usize) -> Result<u32, common::EvalError> {
103    if offset + 4 > bytes.len() {
104        return Err(common::EvalError::new(
105            "primitive `scan_dfa` truncated u32 field. Fix: encode all header fields.",
106        ));
107    }
108    Ok(u32::from_le_bytes([
109        bytes[offset],
110        bytes[offset + 1],
111        bytes[offset + 2],
112        bytes[offset + 3],
113    ]))
114}