rusty_regex/
dfa.rs

1use std::collections::{HashMap, HashSet, BTreeSet};
2use crate::nfa::{NFA, TransitionLabel};
3use crate::error::RegexError;
4use crate::ast::PredefClass;
5
6#[derive(Debug)]
7pub struct DFA {
8    pub start: usize,
9    pub states: Vec<DFANode>,
10    pub accept_states: HashSet<usize>,
11}
12
13#[derive(Debug)]
14pub struct DFANode {
15    /// Mapping: input character -> target DFA state
16    pub transitions: HashMap<char, usize>,
17    /// The set of NFA states represented by this DFA node
18    pub nfa_states: HashSet<usize>,
19}
20
21/// Computes the epsilon-closure for a set of NFA states.
22fn epsilon_closure(nfa: &NFA, states: &HashSet<usize>) -> HashSet<usize> {
23    let mut closure = states.clone();
24    let mut stack: Vec<usize> = states.iter().cloned().collect();
25    while let Some(state_id) = stack.pop() {
26        for transition in &nfa.states[state_id].transitions {
27            if let TransitionLabel::Epsilon = transition.label {
28                if closure.insert(transition.target) {
29                    stack.push(transition.target);
30                }
31            }
32        }
33    }
34    closure
35}
36
37/// For a set of NFA states, returns the set reachable on a given character.
38fn move_on_char(nfa: &NFA, states: &HashSet<usize>, c: char) -> HashSet<usize> {
39    let mut result = HashSet::new();
40    for &state_id in states {
41        for transition in &nfa.states[state_id].transitions {
42            // Skip epsilon transitions.
43            if let TransitionLabel::Epsilon = transition.label {
44                continue;
45            }
46            if matches_label(&transition.label, c) {
47                result.insert(transition.target);
48            }
49        }
50    }
51    result
52}
53
54/// Checks if a transition label matches a given input character.
55fn matches_label(label: &TransitionLabel, c: char) -> bool {
56    match label {
57        TransitionLabel::Char(ch) => *ch == c,
58        TransitionLabel::Any => true,
59        TransitionLabel::CharClass(chars) => chars.contains(&c),
60        TransitionLabel::Predefined(predef) => match predef {
61            PredefClass::Digit => c.is_ascii_digit(),
62            PredefClass::Word => c.is_alphanumeric() || c == '_',
63        },
64        _ => false,
65    }
66}
67
68/// Converts the provided NFA into a DFA using subset construction.
69pub fn compile_nfa_to_dfa(nfa: &NFA) -> Result<DFA, RegexError> {
70    let mut dfa_states: Vec<DFANode> = Vec::new();
71    let mut dfa_state_map: HashMap<BTreeSet<usize>, usize> = HashMap::new();
72    let mut unmarked_states: Vec<usize> = Vec::new();
73    
74    // Start with the epsilon-closure of the NFA's start state.
75    let mut start_set = HashSet::new();
76    start_set.insert(nfa.start);
77    start_set = epsilon_closure(nfa, &start_set);
78    let start_key: BTreeSet<usize> = start_set.iter().cloned().collect();
79    dfa_states.push(DFANode { transitions: HashMap::new(), nfa_states: start_set.clone() });
80    dfa_state_map.insert(start_key, 0);
81    unmarked_states.push(0);
82    
83    // Define the alphabet: ASCII characters from 32 to 126.
84    let alphabet: Vec<char> = (32u8..=126).map(|c| c as char).collect();
85    
86    // Process unmarked DFA states.
87    while let Some(dfa_state_id) = unmarked_states.pop() {
88        // Clone the current state's NFA state set to avoid borrowing issues.
89        let current_nfa_states = dfa_states[dfa_state_id].nfa_states.clone();
90        for &symbol in &alphabet {
91            let mut move_result = move_on_char(nfa, &current_nfa_states, symbol);
92            if move_result.is_empty() {
93                continue;
94            }
95            move_result = epsilon_closure(nfa, &move_result);
96            let key: BTreeSet<usize> = move_result.iter().cloned().collect();
97            let target_dfa_state = if let Some(&state_id) = dfa_state_map.get(&key) {
98                state_id
99            } else {
100                let new_state_id = dfa_states.len();
101                dfa_states.push(DFANode { transitions: HashMap::new(), nfa_states: move_result.clone() });
102                dfa_state_map.insert(key, new_state_id);
103                unmarked_states.push(new_state_id);
104                new_state_id
105            };
106            dfa_states[dfa_state_id].transitions.insert(symbol, target_dfa_state);
107        }
108    }
109    
110    // Mark DFA accept states if they include the NFA's accept state.
111    let mut accept_states = HashSet::new();
112    for (id, dfa_state) in dfa_states.iter().enumerate() {
113        if dfa_state.nfa_states.contains(&nfa.accept) {
114            accept_states.insert(id);
115        }
116    }
117    
118    Ok(DFA {
119        start: 0,
120        states: dfa_states,
121        accept_states,
122    })
123}
124
125#[cfg(test)]
126mod tests {
127    use super::*;
128    use crate::nfa::{compile_ast_to_nfa, NFA};
129    use crate::ast::{Ast, PredefClass};
130    use crate::error::RegexError;
131    
132    #[test]
133    fn test_compile_nfa_to_dfa_literal() -> Result<(), RegexError> {
134        let ast = Ast::Literal('a');
135        let nfa = compile_ast_to_nfa(&ast)?;
136        let dfa = compile_nfa_to_dfa(&nfa)?;
137        
138        // Ensure DFA has states and at least one accept state.
139        assert!(!dfa.states.is_empty());
140        assert!(!dfa.accept_states.is_empty());
141        
142        // For the literal 'a', the DFA should have a transition on 'a' leading to an accept state.
143        let start_state = &dfa.states[dfa.start];
144        if let Some(&target) = start_state.transitions.get(&'a') {
145            assert!(dfa.accept_states.contains(&target));
146        } else {
147            panic!("DFA did not have a transition for 'a'");
148        }
149        Ok(())
150    }
151    
152    #[test]
153    fn test_compile_nfa_to_dfa_dot() -> Result<(), RegexError> {
154        // Test for the dot ('.') which should match any character.
155        let ast = Ast::Dot;
156        let nfa = compile_ast_to_nfa(&ast)?;
157        let dfa = compile_nfa_to_dfa(&nfa)?;
158        // For each printable ASCII character, the DFA should have a valid transition.
159        let start_state = &dfa.states[dfa.start];
160        for c in (32u8..=126).map(|c| c as char) {
161            if let Some(&target) = start_state.transitions.get(&c) {
162                assert!(dfa.accept_states.contains(&target));
163            }
164        }
165        Ok(())
166    }
167    
168    #[test]
169    fn test_compile_nfa_to_dfa_predefined() -> Result<(), RegexError> {
170        // Test for a predefined class (\d), matching only digits.
171        let ast = Ast::PredefinedClass(PredefClass::Digit);
172        let nfa = compile_ast_to_nfa(&ast)?;
173        let dfa = compile_nfa_to_dfa(&nfa)?;
174        let start_state = &dfa.states[dfa.start];
175        // Verify that digits result in transitions that eventually reach an accept state.
176        for c in '0'..='9' {
177            if let Some(&target) = start_state.transitions.get(&c) {
178                assert!(dfa.accept_states.contains(&target));
179            }
180        }
181        Ok(())
182    }
183    
184    #[test]
185    fn test_compile_nfa_to_dfa_character_class() -> Result<(), RegexError> {
186        // Test a character class [abc]
187        let ast = Ast::CharacterClass(vec!['a', 'b', 'c']);
188        let nfa = compile_ast_to_nfa(&ast)?;
189        let dfa = compile_nfa_to_dfa(&nfa)?;
190        let start_state = &dfa.states[dfa.start];
191        for c in ['a', 'b', 'c'] {
192            if let Some(&target) = start_state.transitions.get(&c) {
193                assert!(dfa.accept_states.contains(&target));
194            }
195        }
196        Ok(())
197    }
198}