// vyre_reference/dual_impls/scan/dfa/reference.rs
use crate::{dual_impls::common, workgroup::Memory};
use vyre_primitives::PatternMatchDfa;
3
/// Magic tag that must open every serialized DFA blob.
const MAGIC: &[u8; 4] = b"VDFA";

/// Size in bytes of the fixed header: the magic tag followed by three `u32`
/// fields (`state_count`, `start`, `accept_count`), i.e. 16 bytes.
const HEADER_LEN: usize = MAGIC.len() + 3 * std::mem::size_of::<u32>();
6
7impl common::ReferenceEvaluator for PatternMatchDfa {
8 fn evaluate(&self, inputs: &[Memory]) -> Result<Memory, common::EvalError> {
9 let haystack = common::one_input(inputs, "scan_dfa")?;
10 let dfa = ParsedDfa::parse(&self.dfa)?;
11 let mut state = dfa.start;
12 let mut offsets = Vec::new();
13 for (offset, byte) in haystack.iter().copied().enumerate() {
14 state = dfa.step(state, byte)?;
15 if dfa.accepts[state] {
16 offsets.push(u32::try_from(offset).map_err(|_| {
17 common::EvalError::new(
18 "primitive `scan_dfa` offset exceeds u32. Fix: split haystacks before 4 GiB.",
19 )
20 })?);
21 }
22 }
23 Ok(common::write_u32s(offsets))
24 }
25}
26
/// Decoded form of a serialized DFA blob, built by `ParsedDfa::parse`.
struct ParsedDfa {
    /// Index of the state in which scanning begins.
    start: usize,
    /// One flag per state; `true` marks an accepting state.
    accepts: Vec<bool>,
    /// Transition table words; the entry for `(state, byte)` is read at
    /// index `state * 256 + byte`.
    transitions: Vec<u32>,
}
32
33impl ParsedDfa {
34 fn parse(bytes: &[u8]) -> Result<Self, common::EvalError> {
35 if bytes.len() < HEADER_LEN || &bytes[..4] != MAGIC {
36 return Err(common::EvalError::new(
37 "primitive `scan_dfa` expected VDFA header. Fix: encode magic, state_count, start, and accept_count.",
38 ));
39 }
40 let state_count = read_u32_at(bytes, 4)? as usize;
41 let start = read_u32_at(bytes, 8)? as usize;
42 let accept_count = read_u32_at(bytes, 12)? as usize;
43 if state_count == 0 || start >= state_count {
44 return Err(common::EvalError::new(
45 "primitive `scan_dfa` has invalid state count/start. Fix: provide at least one state and a valid start state.",
46 ));
47 }
48 let accept_bytes = accept_count.checked_mul(4).ok_or_else(|| {
49 common::EvalError::new(
50 "primitive `scan_dfa` accept table size overflow. Fix: bound DFA state metadata.",
51 )
52 })?;
53 let transition_start = HEADER_LEN + accept_bytes;
54 let transition_words = state_count.checked_mul(256).ok_or_else(|| {
55 common::EvalError::new(
56 "primitive `scan_dfa` transition table size overflow. Fix: bound DFA state count.",
57 )
58 })?;
59 let transition_bytes = transition_words.checked_mul(4).ok_or_else(|| {
60 common::EvalError::new(
61 "primitive `scan_dfa` transition byte size overflow. Fix: bound DFA state count.",
62 )
63 })?;
64 if bytes.len() != transition_start + transition_bytes {
65 return Err(common::EvalError::new(format!(
66 "primitive `scan_dfa` byte length mismatch: got {}, expected {}. Fix: encode accept states followed by state_count*256 u32 transitions.",
67 bytes.len(),
68 transition_start + transition_bytes
69 )));
70 }
71 let mut accepts = vec![false; state_count];
72 for accept in 0..accept_count {
73 let state = read_u32_at(bytes, HEADER_LEN + accept * 4)? as usize;
74 if state >= state_count {
75 return Err(common::EvalError::new(
76 "primitive `scan_dfa` accept state is out of range. Fix: keep accept states below state_count.",
77 ));
78 }
79 accepts[state] = true;
80 }
81 let transitions = common::u32_words(&bytes[transition_start..], "scan_dfa")?;
82 Ok(Self {
83 start,
84 accepts,
85 transitions,
86 })
87 }
88
89 fn step(&self, state: usize, byte: u8) -> Result<usize, common::EvalError> {
90 let offset = state * 256 + usize::from(byte);
91 let next = self.transitions[offset] as usize;
92 if next >= self.accepts.len() {
93 Err(common::EvalError::new(
94 "primitive `scan_dfa` transition targets an out-of-range state. Fix: validate every transition target.",
95 ))
96 } else {
97 Ok(next)
98 }
99 }
100}
101
102fn read_u32_at(bytes: &[u8], offset: usize) -> Result<u32, common::EvalError> {
103 if offset + 4 > bytes.len() {
104 return Err(common::EvalError::new(
105 "primitive `scan_dfa` truncated u32 field. Fix: encode all header fields.",
106 ));
107 }
108 Ok(u32::from_le_bytes([
109 bytes[offset],
110 bytes[offset + 1],
111 bytes[offset + 2],
112 bytes[offset + 3],
113 ]))
114}