Skip to main content

vyre_std/pattern/
dfa_pack.rs

1//! Transition-table compression for GPU dispatch.
2//!
3//! The pack formats trade memory footprint for scan speed. Dense is the
4//! fastest but largest; EquivClass collapses redundant byte columns to
5//! shrink the table when the effective alphabet is small.
6
7use super::types::{Dfa, DfaPackFormat, PackedDfa, INVALID_STATE};
8
9/// Pack a DFA into GPU-uploadable bytes using the selected format.
10///
11/// # Examples
12///
13/// ```
14/// use vyre_std::pattern::{regex_to_nfa::regex_to_nfa, nfa_to_dfa::nfa_to_dfa, dfa_minimize::dfa_minimize, dfa_pack::dfa_pack, types::DfaPackFormat};
15///
16/// let nfa = regex_to_nfa("foo|bar").unwrap();
17/// let dfa = dfa_minimize(&nfa_to_dfa(&nfa).unwrap());
18/// let packed = dfa_pack(&dfa, DfaPackFormat::Dense);
19/// assert!(!packed.bytes.is_empty());
20/// ```
21#[must_use]
22#[inline]
23pub fn dfa_pack(dfa: &Dfa, format: DfaPackFormat) -> PackedDfa {
24    match format {
25        DfaPackFormat::Dense => pack_dense(dfa),
26        DfaPackFormat::EquivClass => pack_equiv_class(dfa),
27    }
28}
29
30fn pack_dense(dfa: &Dfa) -> PackedDfa {
31    // Layout: [format tag: u32][state_count: u32][start: u32][accept bitmap]
32    //         [transitions: state_count × 256 × u32]
33    let mut bytes = Vec::new();
34    bytes.extend_from_slice(&0u32.to_le_bytes()); // format tag 0 = Dense
35    bytes.extend_from_slice(&dfa.state_count.to_le_bytes());
36    bytes.extend_from_slice(&dfa.start.to_le_bytes());
37    write_accept(&mut bytes, &dfa.accept);
38    for &t in &dfa.transitions {
39        let word: u32 = t;
40        bytes.extend_from_slice(&word.to_le_bytes());
41    }
42    PackedDfa {
43        format: DfaPackFormat::Dense,
44        state_count: dfa.state_count,
45        start: dfa.start,
46        bytes,
47    }
48}
49
50fn pack_equiv_class(dfa: &Dfa) -> PackedDfa {
51    // Build byte → class table by column-equivalence.
52    let state_count = dfa.state_count as usize;
53    let mut columns: Vec<Vec<u32>> = Vec::with_capacity(256);
54    for byte in 0u8..=255 {
55        let col: Vec<u32> = (0..state_count)
56            .map(|s| dfa.transitions[s * 256 + byte as usize])
57            .collect();
58        columns.push(col);
59    }
60    let mut classes: Vec<u8> = Vec::with_capacity(256);
61    let mut class_representatives: Vec<Vec<u32>> = Vec::new();
62    for col in &columns {
63        let mut found = None;
64        for (idx, rep) in class_representatives.iter().enumerate() {
65            if rep == col {
66                found = Some(idx);
67                break;
68            }
69        }
70        match found {
71            Some(idx) => classes.push(idx as u8),
72            None => {
73                classes.push(class_representatives.len() as u8);
74                class_representatives.push(col.clone());
75            }
76        }
77    }
78    let num_classes = class_representatives.len() as u32;
79
80    // Layout: [format tag: u32 = 1][state_count: u32][start: u32][num_classes: u32]
81    //         [class table: 256 × u8 padded to u32]
82    //         [accept bitmap]
83    //         [transitions: state_count × num_classes × u32]
84    let mut bytes = Vec::new();
85    bytes.extend_from_slice(&1u32.to_le_bytes()); // format tag 1 = EquivClass
86    bytes.extend_from_slice(&dfa.state_count.to_le_bytes());
87    bytes.extend_from_slice(&dfa.start.to_le_bytes());
88    bytes.extend_from_slice(&num_classes.to_le_bytes());
89    for &c in &classes {
90        bytes.push(c);
91    }
92    // Pad class table to 4-byte alignment.
93    while bytes.len() % 4 != 0 {
94        bytes.push(0);
95    }
96    write_accept(&mut bytes, &dfa.accept);
97    // class_representatives is indexed [class][state]; the iteration
98    // order (state outer, class inner) transposes it into the packed
99    // layout. The needless_range_loop lint would prefer iterator
100    // access but the transpose semantics are clearer as indexed.
101    #[allow(clippy::needless_range_loop)]
102    for state in 0..state_count {
103        for class in 0..num_classes as usize {
104            bytes.extend_from_slice(&class_representatives[class][state].to_le_bytes());
105        }
106    }
107    PackedDfa {
108        format: DfaPackFormat::EquivClass,
109        state_count: dfa.state_count,
110        start: dfa.start,
111        bytes,
112    }
113}
114
/// Append an accept-state bitmap to `bytes`.
///
/// Emits a little-endian `u32` word count (`ceil(accept.len() / 32)`),
/// followed by that many little-endian `u32` words. Bit `i % 32` of word
/// `i / 32` is set when `accept[i]` is true. An empty slice writes only the
/// zero count.
fn write_accept(bytes: &mut Vec<u8>, accept: &[bool]) {
    let word_count = accept.len().div_ceil(32) as u32;
    bytes.extend_from_slice(&word_count.to_le_bytes());
    for group in accept.chunks(32) {
        let packed = group
            .iter()
            .enumerate()
            .filter(|&(_, &on)| on)
            .fold(0u32, |acc, (bit, _)| acc | (1 << bit));
        bytes.extend_from_slice(&packed.to_le_bytes());
    }
}
135
136/// Unpack a [`PackedDfa`] back into a [`Dfa`]. Used by tests and by
137/// consumers that need to verify a packed buffer round-trips.
138///
139/// # Errors
140///
141/// Returns `None` when the buffer is malformed or the format tag does not
142/// match any known encoding.
143#[must_use]
144#[inline]
145pub fn dfa_unpack(packed: &PackedDfa) -> Option<Dfa> {
146    let bytes = &packed.bytes;
147    let tag = u32::from_le_bytes(bytes.get(0..4)?.try_into().ok()?);
148    match (tag, packed.format) {
149        (0, DfaPackFormat::Dense) => unpack_dense(bytes),
150        (1, DfaPackFormat::EquivClass) => unpack_equiv_class(bytes),
151        _ => None,
152    }
153}
154
155fn unpack_dense(bytes: &[u8]) -> Option<Dfa> {
156    let state_count = u32::from_le_bytes(bytes.get(4..8)?.try_into().ok()?);
157    let start = u32::from_le_bytes(bytes.get(8..12)?.try_into().ok()?);
158    let accept_words = u32::from_le_bytes(bytes.get(12..16)?.try_into().ok()?) as usize;
159    let accept_start = 16;
160    let accept_end = accept_start + accept_words * 4;
161    let accept = read_accept(&bytes[accept_start..accept_end], state_count as usize);
162    let trans_start = accept_end;
163    let trans_end = trans_start + (state_count as usize) * 256 * 4;
164    let transitions: Vec<u32> = bytes[trans_start..trans_end]
165        .chunks_exact(4)
166        .map(|c| u32::from_le_bytes(c.try_into().unwrap_or([0; 4])))
167        .collect();
168    Some(Dfa {
169        state_count,
170        transitions,
171        start,
172        accept,
173    })
174}
175
176fn unpack_equiv_class(bytes: &[u8]) -> Option<Dfa> {
177    let state_count = u32::from_le_bytes(bytes.get(4..8)?.try_into().ok()?);
178    let start = u32::from_le_bytes(bytes.get(8..12)?.try_into().ok()?);
179    let num_classes = u32::from_le_bytes(bytes.get(12..16)?.try_into().ok()?) as usize;
180    let class_start = 16;
181    let class_end = class_start + 256;
182    let classes = &bytes[class_start..class_end];
183    let aligned_end = (class_end + 3) & !3;
184    let accept_words =
185        u32::from_le_bytes(bytes.get(aligned_end..aligned_end + 4)?.try_into().ok()?) as usize;
186    let accept_data_start = aligned_end + 4;
187    let accept_data_end = accept_data_start + accept_words * 4;
188    let accept = read_accept(
189        &bytes[accept_data_start..accept_data_end],
190        state_count as usize,
191    );
192
193    let trans_start = accept_data_end;
194    let trans_count = (state_count as usize) * num_classes;
195    let mut class_trans: Vec<u32> = Vec::with_capacity(trans_count);
196    for i in 0..trans_count {
197        let off = trans_start + i * 4;
198        class_trans.push(u32::from_le_bytes(
199            bytes.get(off..off + 4)?.try_into().ok()?,
200        ));
201    }
202
203    let mut transitions = vec![INVALID_STATE; (state_count as usize) * 256];
204    for state in 0..state_count as usize {
205        for byte in 0u8..=255 {
206            let class = classes[byte as usize] as usize;
207            transitions[state * 256 + byte as usize] = class_trans[state * num_classes + class];
208        }
209    }
210    Some(Dfa {
211        state_count,
212        transitions,
213        start,
214        accept,
215    })
216}
217
/// Expand an accept bitmap into exactly `state_count` flags.
///
/// `bytes` is read as little-endian `u32` words; bit `i % 32` of word `i / 32`
/// becomes flag `i`. A trailing partial word (fewer than 4 bytes) is ignored,
/// extra bits beyond `state_count` are dropped, and missing bits read as
/// `false`.
fn read_accept(bytes: &[u8], state_count: usize) -> Vec<bool> {
    let mut flags: Vec<bool> = Vec::with_capacity(state_count);
    flags.extend(
        bytes
            .chunks_exact(4)
            .flat_map(|quad| {
                let word = u32::from_le_bytes([quad[0], quad[1], quad[2], quad[3]]);
                (0..32).map(move |bit| (word >> bit) & 1 == 1)
            })
            .take(state_count),
    );
    flags.resize(state_count, false);
    flags
}
237
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pattern::{
        dfa_minimize::dfa_minimize, nfa_to_dfa::nfa_to_dfa, regex_to_nfa::regex_to_nfa,
    };

    /// Compile `regex`, pack it with `format`, unpack it, and require every
    /// field of the result to equal the original DFA.
    fn roundtrip(regex: &str, format: DfaPackFormat) {
        let nfa = regex_to_nfa(regex).unwrap();
        let dfa = dfa_minimize(&nfa_to_dfa(&nfa).unwrap());
        let packed = dfa_pack(&dfa, format);
        let unpacked = dfa_unpack(&packed).expect("unpack");
        assert_eq!(unpacked.state_count, dfa.state_count, "regex `{regex}`");
        assert_eq!(unpacked.start, dfa.start);
        assert_eq!(unpacked.accept, dfa.accept);
        assert_eq!(unpacked.transitions, dfa.transitions);
    }

    #[test]
    fn dense_roundtrip_literal() {
        roundtrip("hello", DfaPackFormat::Dense);
    }

    #[test]
    fn dense_roundtrip_alternation() {
        roundtrip("foo|bar|baz", DfaPackFormat::Dense);
    }

    #[test]
    fn dense_roundtrip_kleene() {
        roundtrip("a*b+c?", DfaPackFormat::Dense);
    }

    #[test]
    fn equiv_class_roundtrip_literal() {
        roundtrip("hello", DfaPackFormat::EquivClass);
    }

    #[test]
    fn equiv_class_roundtrip_char_class() {
        roundtrip("[a-z]+", DfaPackFormat::EquivClass);
    }

    // EquivClass previously lacked coverage for the same shapes Dense is tested on.
    #[test]
    fn equiv_class_roundtrip_alternation() {
        roundtrip("foo|bar|baz", DfaPackFormat::EquivClass);
    }

    #[test]
    fn equiv_class_roundtrip_kleene() {
        roundtrip("a*b+c?", DfaPackFormat::EquivClass);
    }

    #[test]
    fn equiv_class_fewer_bytes_than_dense_for_small_alphabet() {
        let nfa = regex_to_nfa("abc").unwrap();
        let dfa = dfa_minimize(&nfa_to_dfa(&nfa).unwrap());
        let dense = dfa_pack(&dfa, DfaPackFormat::Dense);
        let equiv = dfa_pack(&dfa, DfaPackFormat::EquivClass);
        assert!(
            equiv.bytes.len() < dense.bytes.len(),
            "equiv-class must be smaller for narrow alphabets: dense={} equiv={}",
            dense.bytes.len(),
            equiv.bytes.len()
        );
    }

    // The `None` paths of `dfa_unpack`: tag/format disagreement, unknown tag,
    // and a buffer too short to hold the tag.
    #[test]
    fn unpack_rejects_format_mismatch() {
        let nfa = regex_to_nfa("abc").unwrap();
        let dfa = dfa_minimize(&nfa_to_dfa(&nfa).unwrap());
        let mut packed = dfa_pack(&dfa, DfaPackFormat::Dense);
        packed.format = DfaPackFormat::EquivClass;
        assert!(dfa_unpack(&packed).is_none());
    }

    #[test]
    fn unpack_rejects_unknown_tag() {
        let nfa = regex_to_nfa("abc").unwrap();
        let dfa = dfa_minimize(&nfa_to_dfa(&nfa).unwrap());
        let mut packed = dfa_pack(&dfa, DfaPackFormat::Dense);
        packed.bytes[0] = 7;
        assert!(dfa_unpack(&packed).is_none());
    }

    #[test]
    fn unpack_rejects_empty_buffer() {
        let nfa = regex_to_nfa("abc").unwrap();
        let dfa = dfa_minimize(&nfa_to_dfa(&nfa).unwrap());
        let mut packed = dfa_pack(&dfa, DfaPackFormat::EquivClass);
        packed.bytes.clear();
        assert!(dfa_unpack(&packed).is_none());
    }
}