sapper/recognizer/
nfa.rs

1use std::collections::HashSet;
2use std::u64;
3use self::CharacterClass::{Ascii, ValidChars, InvalidChars};
4
5#[cfg(test)] use std::collections::BTreeSet;
6
/// A set of characters with a bitmask fast path for low code points.
///
/// Code points 1..=64 are stored as bits of `low_mask`, code points 65..=128
/// as bits of `high_mask`, and every other character (including U+0000) in
/// the `non_ascii` hash set.
#[derive(PartialEq, Eq, Clone, Debug, Default)]
pub struct CharSet {
    low_mask: u64,
    high_mask: u64,
    non_ascii: HashSet<char>
}

impl CharSet {
    /// Creates an empty set.
    pub fn new() -> CharSet {
        CharSet { low_mask: 0, high_mask: 0, non_ascii: HashSet::new() }
    }

    /// Locates the bitmask slot for `char`.
    ///
    /// Returns `None` when the character belongs in `non_ascii`, otherwise
    /// `Some((is_high_mask, bit))`. `wrapping_sub` sends U+0000 to
    /// `u32::MAX` (the `non_ascii` bucket) instead of overflowing, which
    /// would panic in debug builds.
    fn ascii_bit(char: char) -> Option<(bool, u64)> {
        let val = (char as u32).wrapping_sub(1);

        if val > 127 {
            None
        } else if val > 63 {
            Some((true, 1u64 << (val - 64)))
        } else {
            Some((false, 1u64 << val))
        }
    }

    /// Adds `char` to the set.
    pub fn insert(&mut self, char: char) {
        match CharSet::ascii_bit(char) {
            None => {
                self.non_ascii.insert(char);
            }
            Some((true, bit)) => self.high_mask |= bit,
            Some((false, bit)) => self.low_mask |= bit,
        }
    }

    /// Returns `true` if `char` has been inserted into the set.
    pub fn contains(&self, char: char) -> bool {
        match CharSet::ascii_bit(char) {
            None => self.non_ascii.contains(&char),
            Some((true, bit)) => self.high_mask & bit != 0,
            Some((false, bit)) => self.low_mask & bit != 0,
        }
    }
}
47
/// Character predicate used to label NFA states: a thread moves into a
/// state only when the next input character matches that state's class.
#[derive(PartialEq, Eq, Clone)]
pub enum CharacterClass {
    /// Bitmask form `(high, low, unicode)`, as interpreted by `matches`.
    /// With `cp` the code point, bit `cp - 1` of `low` covers code points
    /// 1..=64, bit `cp - 65` of `high` covers 65..=128, and `unicode` is the
    /// answer for code points above 128. Note that `high` comes first.
    Ascii(u64, u64, bool),
    /// Matches exactly the characters contained in the set.
    ValidChars(CharSet),
    /// Matches every character that is *not* contained in the set.
    InvalidChars(CharSet)
}
54
55impl CharacterClass {
56    pub fn any() -> CharacterClass {
57        Ascii(u64::MAX, u64::MAX, true)
58    }
59
60    pub fn valid(string: &str) -> CharacterClass {
61        ValidChars(CharacterClass::str_to_set(string))
62    }
63
64    pub fn invalid(string: &str) -> CharacterClass {
65        InvalidChars(CharacterClass::str_to_set(string))
66    }
67
68    pub fn valid_char(char: char) -> CharacterClass {
69        let val = char as u32 - 1;
70
71        if val > 127 {
72            ValidChars(CharacterClass::char_to_set(char))
73        } else if val > 63 {
74            Ascii(1 << val - 64, 0, false)
75        } else {
76            Ascii(0, 1 << val, false)
77        }
78    }
79
80    pub fn invalid_char(char: char) -> CharacterClass {
81        let val = char as u32 - 1;
82
83        if val > 127 {
84            InvalidChars(CharacterClass::char_to_set(char))
85        } else if val > 63 {
86            Ascii(u64::MAX ^ (1 << val - 64), u64::MAX, true)
87        } else {
88            Ascii(u64::MAX, u64::MAX ^ (1 << val), true)
89        }
90    }
91
92
93    pub fn matches(&self, char: char) -> bool {
94        match *self {
95            ValidChars(ref valid) => valid.contains(char),
96            InvalidChars(ref invalid) => !invalid.contains(char),
97            Ascii(high, low, unicode) => {
98                let val = char as u32 - 1;
99                if val > 127 {
100                    unicode
101                } else if val > 63 {
102                    high & (1 << (val - 64)) != 0
103                } else {
104                    low & (1 << val) != 0
105                }
106            }
107        }
108    }
109
110    fn char_to_set(char: char) -> CharSet {
111        let mut set = CharSet::new();
112        set.insert(char);
113        set
114    }
115
116    fn str_to_set(string: &str) -> CharSet {
117        let mut set = CharSet::new();
118        for char in string.chars() {
119            set.insert(char);
120        }
121        set
122    }
123}
124
/// One in-flight path through the NFA while a string is being matched.
#[derive(Clone)]
struct Thread {
    /// Index of the state this thread currently occupies.
    state: usize,
    /// Completed captures as `(begin, end)` ranges; `extract` slices the
    /// input string with them, so they are byte offsets.
    captures: Vec<(usize, usize)>,
    /// Start offset of a capture that has been opened but not yet closed.
    capture_begin: Option<usize>
}

impl Thread {
    /// Creates a thread at the root state (index 0) with no captures.
    pub fn new() -> Thread {
        Thread { state: 0, captures: Vec::new(), capture_begin: None }
    }

    /// Opens a capture starting at offset `start`.
    #[inline]
    pub fn start_capture(&mut self, start: usize) {
        self.capture_begin = Some(start);
    }

    /// Closes the open capture at offset `end` and records it.
    ///
    /// # Panics
    /// Panics if no capture is open; callers check `capture_begin` first.
    #[inline]
    pub fn end_capture(&mut self, end: usize) {
        let begin = self.capture_begin
            .take()
            .expect("end_capture called without an open capture");
        self.captures.push((begin, end));
    }

    /// Returns the text of every recorded capture, sliced out of `source`.
    ///
    /// # Panics
    /// Panics if a recorded range is out of bounds or not on a char boundary.
    pub fn extract<'a>(&self, source: &'a str) -> Vec<&'a str> {
        self.captures.iter().map(|&(begin, end)| &source[begin..end]).collect()
    }
}
152
/// A node of the NFA.
#[derive(Clone)]
pub struct State<T> {
    /// Position of this state in the owning NFA's state list; equality of
    /// states compares only this field (see the `PartialEq` impl).
    pub index: usize,
    /// Characters that let a thread move *into* this state.
    pub chars: CharacterClass,
    /// Indices of the states reachable from this one.
    pub next_states: Vec<usize>,
    /// Set by `NFA::acceptance`; `NFA::process` only returns threads whose
    /// final state has this flag.
    pub acceptance: bool,
    /// Set by `NFA::start_capture`. Matching reads the NFA's parallel
    /// `start_capture` vector, not this field; entering a flagged state
    /// opens a capture.
    pub start_capture: bool,
    /// Set by `NFA::end_capture`. Matching reads the NFA's parallel
    /// `end_capture` vector; leaving a flagged state for a higher-indexed
    /// state closes the open capture.
    pub end_capture: bool,
    /// Caller-supplied data attached via `NFA::metadata`; not read here.
    pub metadata: Option<T>
}
163
164impl<T> PartialEq for State<T> {
165    fn eq(&self, other: &State<T>) -> bool {
166        self.index == other.index
167    }
168}
169
170impl<T> State<T> {
171    pub fn new(index: usize, chars: CharacterClass) -> State<T> {
172        State {
173            index: index,
174            chars: chars,
175            next_states: Vec::new(),
176            acceptance: false,
177            start_capture: false,
178            end_capture: false,
179            metadata: None,
180        }
181    }
182}
183
/// Successful result of running a string through the NFA.
#[derive(Debug)]
pub struct Match<'a> {
    /// Index of the accepting state the winning thread ended in.
    pub state: usize,
    /// Captured substrings, borrowed from the matched input.
    pub captures: Vec<&'a str>,
}

impl<'a> Match<'a> {
    /// Creates a match whose lifetime is tied to the borrowed `captures`.
    pub fn new(state: usize, captures: Vec<&'a str>) -> Match<'a> {
        Match { state: state, captures: captures }
    }
}
194
/// A nondeterministic finite automaton over characters.
///
/// `states[0]` is the root. The three flag vectors hold one entry per
/// state and are indexed like `states` (they grow together in `new_state`).
#[derive(Clone)]
pub struct NFA<T> {
    // All states; a state's position here equals its `index` field.
    states: Vec<State<T>>,
    // Per-state flag read by `capture` when a thread enters a state.
    start_capture: Vec<bool>,
    // Per-state flag read by `capture` when a thread leaves a state.
    end_capture: Vec<bool>,
    // Per-state acceptance flag, mirrored from `State::acceptance`.
    acceptance: Vec<bool>,
}
202
203impl<T> NFA<T> {
204    pub fn new() -> NFA<T> {
205        let root = State::new(0, CharacterClass::any());
206        NFA {
207            states: vec![root],
208            start_capture: vec![false],
209            end_capture: vec![false],
210            acceptance: vec![false],
211        }
212    }
213
214    pub fn process<'a, I, F>(&self, string: &'a str, mut ord: F)
215                             -> Result<Match<'a>, String>
216        where I: Ord, F: FnMut(usize) -> I
217    {
218            let mut threads = vec![Thread::new()];
219
220            for (i, char) in string.chars().enumerate() {
221                let next_threads = self.process_char(threads, char, i);
222
223                if next_threads.is_empty() {
224                    return Err(format!("Couldn't process {}", string));
225                }
226
227                threads = next_threads;
228            }
229
230            let returned = threads.into_iter().filter(|thread| {
231                self.get(thread.state).acceptance
232            });
233
234            let thread = returned.fold(None, |prev, y| {
235                let y_v = ord(y.state);
236                match prev {
237                    None => Some((y_v, y)),
238                    Some((x_v, x)) => {
239                        if x_v < y_v {Some((y_v, y))} else {Some((x_v, x))}
240                    }
241                }
242            }).map(|p| p.1);
243
244            match thread {
245                None => Err("The string was exhausted before reaching an \
246                             acceptance state".to_string()),
247                Some(mut thread) => {
248                    if thread.capture_begin.is_some() {
249                        thread.end_capture(string.len());
250                    }
251                    let state = self.get(thread.state);
252                    Ok(Match::new(state.index, thread.extract(string)))
253                }
254            }
255        }
256
257    #[inline]
258    fn process_char<'a>(&self, threads: Vec<Thread>,
259                        char: char, pos: usize) -> Vec<Thread> {
260        let mut returned = Vec::with_capacity(threads.len());
261
262        for mut thread in threads.into_iter() {
263            let current_state = self.get(thread.state);
264
265            let mut count = 0;
266            let mut found_state = 0;
267
268            for &index in current_state.next_states.iter() {
269                let state = &self.states[index];
270
271                if state.chars.matches(char) {
272                    count += 1;
273                    found_state = index;
274                }
275            }
276
277            if count == 1 {
278                thread.state = found_state;
279                capture(self, &mut thread, current_state.index, found_state, pos);
280                returned.push(thread);
281                continue;
282            }
283
284            for &index in current_state.next_states.iter() {
285                let state = &self.states[index];
286                if state.chars.matches(char) {
287                    let mut thread = fork_thread(&thread, state);
288                    capture(self, &mut thread, current_state.index, index, pos);
289                    returned.push(thread);
290                }
291            }
292
293        }
294
295        returned
296    }
297
298    #[inline]
299    pub fn get<'a>(&'a self, state: usize) -> &'a State<T> {
300        &self.states[state]
301    }
302
303    pub fn get_mut<'a>(&'a mut self, state: usize) -> &'a mut State<T> {
304        &mut self.states[state]
305    }
306
307    pub fn put(&mut self, index: usize, chars: CharacterClass) -> usize {
308        {
309            let state = self.get(index);
310
311            for &index in state.next_states.iter() {
312                let state = self.get(index);
313                if state.chars == chars {
314                    return index;
315                }
316            }
317        }
318
319        let state = self.new_state(chars);
320        self.get_mut(index).next_states.push(state);
321        state
322    }
323
324    pub fn put_state(&mut self, index: usize, child: usize) {
325        if !self.states[index].next_states.contains(&child) {
326            self.get_mut(index).next_states.push(child);
327        }
328    }
329
330    pub fn acceptance(&mut self, index: usize) {
331        self.get_mut(index).acceptance = true;
332        self.acceptance[index] = true;
333    }
334
335    pub fn start_capture(&mut self, index: usize) {
336        self.get_mut(index).start_capture = true;
337        self.start_capture[index] = true;
338    }
339
340    pub fn end_capture(&mut self, index: usize) {
341        self.get_mut(index).end_capture = true;
342        self.end_capture[index] = true;
343    }
344
345    pub fn metadata(&mut self, index: usize, metadata: T) {
346        self.get_mut(index).metadata = Some(metadata);
347    }
348
349    fn new_state(&mut self, chars: CharacterClass) -> usize {
350        let index = self.states.len();
351        let state = State::new(index, chars);
352        self.states.push(state);
353
354        self.acceptance.push(false);
355        self.start_capture.push(false);
356        self.end_capture.push(false);
357
358        index
359    }
360}
361
362#[inline]
363fn fork_thread<T>(thread: &Thread, state: &State<T>) -> Thread {
364    let mut new_trace = thread.clone();
365    new_trace.state = state.index;
366    new_trace
367}
368
369#[inline]
370fn capture<T>(nfa: &NFA<T>, thread: &mut Thread, current_state: usize,
371              next_state: usize, pos: usize) {
372    if thread.capture_begin == None && nfa.start_capture[next_state] {
373        thread.start_capture(pos);
374    }
375
376    if thread.capture_begin != None && nfa.end_capture[current_state] &&
377        next_state > current_state {
378            thread.end_capture(pos);
379        }
380}
381
/// A single linear path accepts exactly the word it spells.
#[test]
fn basic_test() {
    let mut nfa = NFA::<()>::new();
    let a = nfa.put(0, CharacterClass::valid("h"));
    let b = nfa.put(a, CharacterClass::valid("e"));
    let c = nfa.put(b, CharacterClass::valid("l"));
    let d = nfa.put(c, CharacterClass::valid("l"));
    let e = nfa.put(d, CharacterClass::valid("o"));
    nfa.acceptance(e);

    let m = nfa.process("hello", |a| a);

    assert_eq!(m.unwrap().state, e, "You didn't get the right final state");
}
393
394    assert!(m.unwrap().state == e, "You didn't get the right final state");
395}
396
397#[test]
398fn multiple_solutions() {
399    let mut nfa = NFA::<()>::new();
400    let a1 = nfa.put(0, CharacterClass::valid("n"));
401    let b1 = nfa.put(a1, CharacterClass::valid("e"));
402    let c1 = nfa.put(b1, CharacterClass::valid("w"));
403    nfa.acceptance(c1);
404
405    let a2 = nfa.put(0, CharacterClass::invalid(""));
406    let b2 = nfa.put(a2, CharacterClass::invalid(""));
407    let c2 = nfa.put(b2, CharacterClass::invalid(""));
408    nfa.acceptance(c2);
409
410    let m = nfa.process("new", |a| a);
411
412    assert!(m.unwrap().state == c2, "The two states were not found");
413}
414
415#[test]
416fn multiple_paths() {
417    let mut nfa = NFA::<()>::new();
418    let a = nfa.put(0, CharacterClass::valid("t"));   // t
419    let b1 = nfa.put(a, CharacterClass::valid("h"));  // th
420    let c1 = nfa.put(b1, CharacterClass::valid("o")); // tho
421    let d1 = nfa.put(c1, CharacterClass::valid("m")); // thom
422    let e1 = nfa.put(d1, CharacterClass::valid("a")); // thoma
423    let f1 = nfa.put(e1, CharacterClass::valid("s")); // thomas
424
425    let b2 = nfa.put(a, CharacterClass::valid("o"));  // to
426    let c2 = nfa.put(b2, CharacterClass::valid("m")); // tom
427
428    nfa.acceptance(f1);
429    nfa.acceptance(c2);
430
431    let thomas = nfa.process("thomas", |a| a);
432    let tom = nfa.process("tom", |a| a);
433    let thom = nfa.process("thom", |a| a);
434    let nope = nfa.process("nope", |a| a);
435
436    assert!(thomas.unwrap().state == f1, "thomas was parsed correctly");
437    assert!(tom.unwrap().state == c2, "tom was parsed correctly");
438    assert!(thom.is_err(), "thom didn't reach an acceptance state");
439    assert!(nope.is_err(), "nope wasn't parsed");
440}
441
442#[test]
443fn repetitions() {
444    let mut nfa = NFA::<()>::new();
445    let a = nfa.put(0, CharacterClass::valid("p"));   // p
446    let b = nfa.put(a, CharacterClass::valid("o"));   // po
447    let c = nfa.put(b, CharacterClass::valid("s"));   // pos
448    let d = nfa.put(c, CharacterClass::valid("t"));   // post
449    let e = nfa.put(d, CharacterClass::valid("s"));   // posts
450    let f = nfa.put(e, CharacterClass::valid("/"));   // posts/
451    let g = nfa.put(f, CharacterClass::invalid("/")); // posts/[^/]
452    nfa.put_state(g, g);
453
454    nfa.acceptance(g);
455
456    let post = nfa.process("posts/1", |a| a);
457    let new_post = nfa.process("posts/new", |a| a);
458    let invalid = nfa.process("posts/", |a| a);
459
460    assert!(post.unwrap().state == g, "posts/1 was parsed");
461    assert!(new_post.unwrap().state == g, "posts/new was parsed");
462    assert!(invalid.is_err(), "posts/ was invalid");
463}
464
465#[test]
466fn repetitions_with_ambiguous() {
467    let mut nfa = NFA::<()>::new();
468    let a  = nfa.put(0, CharacterClass::valid("p"));   // p
469    let b  = nfa.put(a, CharacterClass::valid("o"));   // po
470    let c  = nfa.put(b, CharacterClass::valid("s"));   // pos
471    let d  = nfa.put(c, CharacterClass::valid("t"));   // post
472    let e  = nfa.put(d, CharacterClass::valid("s"));   // posts
473    let f  = nfa.put(e, CharacterClass::valid("/"));   // posts/
474    let g1 = nfa.put(f, CharacterClass::invalid("/")); // posts/[^/]
475    let g2 = nfa.put(f, CharacterClass::valid("n"));   // posts/n
476    let h2 = nfa.put(g2, CharacterClass::valid("e"));  // posts/ne
477    let i2 = nfa.put(h2, CharacterClass::valid("w"));  // posts/new
478
479    nfa.put_state(g1, g1);
480
481    nfa.acceptance(g1);
482    nfa.acceptance(i2);
483
484    let post = nfa.process("posts/1", |a| a);
485    let ambiguous = nfa.process("posts/new", |a| a);
486    let invalid = nfa.process("posts/", |a| a);
487
488    assert!(post.unwrap().state == g1, "posts/1 was parsed");
489    assert!(ambiguous.unwrap().state == i2, "posts/new was ambiguous");
490    assert!(invalid.is_err(), "posts/ was invalid");
491}
492
493#[test]
494fn captures() {
495    let mut nfa = NFA::<()>::new();
496    let a = nfa.put(0, CharacterClass::valid("n"));
497    let b = nfa.put(a, CharacterClass::valid("e"));
498    let c = nfa.put(b, CharacterClass::valid("w"));
499
500    nfa.acceptance(c);
501    nfa.start_capture(a);
502    nfa.end_capture(c);
503
504    let post = nfa.process("new", |a| a);
505
506    assert_eq!(post.unwrap().captures, vec!["new"]);
507}
508
509#[test]
510fn capture_mid_match() {
511    let mut nfa = NFA::<()>::new();
512    let a = nfa.put(0, valid('p'));
513    let b = nfa.put(a, valid('/'));
514    let c = nfa.put(b, invalid('/'));
515    let d = nfa.put(c, valid('/'));
516    let e = nfa.put(d, valid('c'));
517
518    nfa.put_state(c, c);
519    nfa.acceptance(e);
520    nfa.start_capture(c);
521    nfa.end_capture(c);
522
523    let post = nfa.process("p/123/c", |a| a);
524
525    assert_eq!(post.unwrap().captures, vec!["123"]);
526}
527
528#[test]
529fn capture_multiple_captures() {
530    let mut nfa = NFA::<()>::new();
531    let a = nfa.put(0, valid('p'));
532    let b = nfa.put(a, valid('/'));
533    let c = nfa.put(b, invalid('/'));
534    let d = nfa.put(c, valid('/'));
535    let e = nfa.put(d, valid('c'));
536    let f = nfa.put(e, valid('/'));
537    let g = nfa.put(f, invalid('/'));
538
539    nfa.put_state(c, c);
540    nfa.put_state(g, g);
541    nfa.acceptance(g);
542
543    nfa.start_capture(c);
544    nfa.end_capture(c);
545
546    nfa.start_capture(g);
547    nfa.end_capture(g);
548
549    let post = nfa.process("p/123/c/456", |a| a);
550    assert_eq!(post.unwrap().captures, vec!["123", "456"]);
551}
552
553#[test]
554fn test_ascii_set() {
555    let mut set = CharSet::new();
556    set.insert('?');
557    set.insert('a');
558    set.insert('é');
559
560    assert!(set.contains('?'), "The set contains char 63");
561    assert!(set.contains('a'), "The set contains char 97");
562    assert!(set.contains('é'), "The set contains char 233");
563    assert!(!set.contains('q'), "The set does not contain q");
564    assert!(!set.contains('ü'), "The set does not contain ü");
565}
566
567#[allow(dead_code)]
568fn valid(char: char) -> CharacterClass {
569    CharacterClass::valid_char(char)
570}
571
572#[allow(dead_code)]
573fn invalid(char: char) -> CharacterClass {
574    CharacterClass::invalid_char(char)
575}