tiny_earley/
lib.rs

1#![cfg_attr(feature = "nightly", feature(test))]
2
3#![deny(unsafe_code)]
4
5pub mod forest;
6#[cfg(feature = "load")]
7pub mod load;
8
9use std::{collections::BinaryHeap, ops::Index};
10// #[cfg(feature = "debug")]
11use std::collections::{BTreeMap, BTreeSet};
12
13use self::forest::*;
14
/// Identifier of a grammar symbol, stored as a plain index.
///
/// Indices below a grammar's `S` refer to the symbols declared in `symbol_names`;
/// `Grammar::rule` numbers the symbols it generates from `S` upwards.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct Symbol(u32);
17
/// A grammar over `S` declared symbols, stored as rules with at most two
/// right-hand-side symbols each.
#[derive(Clone)]
pub struct Grammar<const S: usize> {
    /// Rules in binarized form (`rhs0` plus an optional `rhs1`).
    rules: Vec<Rule>,
    /// Start symbol, as passed to `Grammar::new`.
    start_symbol: Symbol,
    /// Display names of the `S` declared symbols, indexed by symbol number.
    symbol_names: [&'static str; S],
    /// Number of generated symbols so far; `rule` numbers them from `S` upwards.
    gen_symbols: u32,
    /// Id given to the next rule added through `rule` (`new` starts it at 2).
    rule_id: usize,
    /// Left-hand side of the most recent `rule` call; `rhs` reuses it.
    lhs: Symbol,
    /// Symbols recorded as nulling, each mapped to a rule id.
    nulling: BTreeMap<Symbol, usize>,
    /// Records appended by `eliminate_nulling`.
    rhs_nulling: Vec<NullingEliminated>,
}
29
/// Which right-hand-side position(s) of a rule `eliminate_nulling` found to be nulling.
#[derive(Clone)]
enum Side {
    /// `rhs0` is nulling.
    Left,
    /// `rhs1` is nulling.
    Right,
    /// Both `rhs0` and `rhs1` are nulling.
    Both,
}
36
/// Record of a nulling symbol found on the right-hand side of a rule by
/// `eliminate_nulling`.
#[derive(Clone)]
struct NullingEliminated {
    /// Id of the rule the symbol appeared in (`None` for generated rules).
    rule_id: Option<usize>,
    /// The nulling symbol.
    sym: Symbol,
    /// Position(s) of the rule that were nulling.
    side: Side,
}
43
44#[derive(Clone, Copy, Debug, Eq, PartialEq)]
45struct Rule {
46    lhs: Symbol,
47    rhs0: Symbol,
48    rhs1: Option<Symbol>,
49    id: Option<usize>,
50}
51
/// Lookup tables derived from a `Grammar` and consulted by the recognizer.
#[derive(Clone)]
struct Tables<const S: usize> {
    /// `prediction_matrix[a][b]` is set when predicting symbol `a` also predicts
    /// symbol `b`; closed reflexively and transitively in
    /// `populate_prediction_matrix`.
    prediction_matrix: [[bool; S]; S],
    /// Start symbol copied from the grammar.
    start_symbol: Symbol,
    /// Copy of the grammar's rules; `dot` values index into this vector.
    rules: Vec<Rule>,
    /// The same rules sorted by `rhs0`, for binary search.
    rules_by_rhs0: Vec<Rule>,
    /// For each declared symbol, the transitions of rules whose `rhs0` is that symbol.
    completions: Vec<Vec<PredictionTransition>>,
    /// Display names copied from the grammar.
    symbol_names: [&'static str; S],
    /// One transition per generated symbol, indexed by `symbol - S`.
    gen_completions: Vec<PredictionTransition>,
}
62
63#[derive(Copy, Clone, Debug, Default)]
64struct PredictionTransition {
65    symbol: Symbol,
66    top: Symbol,
67    dot: usize,
68    is_unary: bool,
69}
70
71// Recognizer
72
/// Earley recognizer that builds a parse forest while scanning.
#[derive(Clone)]
pub struct Recognizer<const S: usize> {
    /// Tables precomputed from the grammar.
    tables: Tables<S>,
    /// Earley sets built so far.
    earley_chart: EarleyChart<S>,
    /// Completed items waiting to be processed by `complete_all_sums_entirely`.
    complete: BinaryHeap<CompletedItem>,
    /// Parse forest filled in during recognition.
    pub forest: Forest,
    /// Node of the start symbol completed from origin 0 during the most recent
    /// `end_earleme`, if any; reset at the start of every `end_earleme`.
    pub finished_node: Option<NodeHandle>,
}
81
/// All Earley sets, stored flat: the items of every set live in one vector and
/// `indices` marks where each set begins.
#[derive(Clone)]
struct EarleyChart<const S: usize> {
    /// Predicted-symbol flags, one row per set.
    predicted: Vec<[bool; S]>,
    /// Start offset into `items` of each set.
    indices: Vec<usize>,
    /// Medial items of all sets, concatenated in set order.
    items: Vec<Item>,
}
88
89impl<const S: usize> EarleyChart<S> {
90    fn next_set(&mut self, predicted: Option<[bool; S]>) {
91        self.predicted.push(predicted.unwrap_or([false; S]));
92        self.indices.push(self.items.len());
93    }
94
95    fn new() -> Self {
96        EarleyChart {
97            predicted: vec![],
98            indices: vec![],
99            items: vec![],
100        }
101    }
102
103    fn len(&self) -> usize {
104        self.indices.len()
105    }
106
107    fn predicted(&self, index: usize) -> &[bool] {
108        &self.predicted[index][..]
109    }
110
111    fn medial(&mut self) -> &mut Vec<Item> {
112        &mut self.items
113    }
114
115    fn next_medial(&self) -> &[Item] {
116        &self.items[self.indices[self.indices.len() - 1]..]
117    }
118
119    fn last(&self) -> EarleySetRef {
120        let items = &self.items[self.indices[self.indices.len() - 1]..];
121        let predicted = &self.predicted.last().unwrap()[..];
122        EarleySetRef { items, predicted }
123    }
124
125    fn next_to_last(&self) -> EarleySetRef {
126        let items =
127            &self.items[self.indices[self.indices.len() - 2]..self.indices[self.indices.len() - 1]];
128        let predicted = &self.predicted[self.predicted.len() - 2][..];
129        EarleySetRef { items, predicted }
130    }
131
132    fn last_mut(&mut self) -> EarleySetMut {
133        EarleySetMut {
134            items: &mut self.items[self.indices[self.indices.len() - 1]..],
135            predicted: &mut self.predicted.last_mut().unwrap()[..],
136        }
137    }
138}
139
/// Mutable view of one Earley set inside an `EarleyChart`.
struct EarleySetMut<'a> {
    /// The set's medial items.
    items: &'a mut [Item],
    /// The set's predicted-symbol flags.
    predicted: &'a mut [bool],
}
144
/// Shared view of one Earley set inside an `EarleyChart`.
struct EarleySetRef<'a> {
    /// The set's medial items.
    items: &'a [Item],
    /// The set's predicted-symbol flags.
    predicted: &'a [bool],
}
149
150impl<const S: usize> Index<usize> for EarleyChart<S> {
151    type Output = [Item];
152    fn index(&self, index: usize) -> &Self::Output {
153        &self.items[self.indices[index]..self.indices[index + 1]]
154    }
155}
156
/// A standalone Earley set: predicted-symbol flags plus medial items.
#[derive(Clone)]
struct EarleySet<const S: usize> {
    /// Predicted-symbol flags, indexed by symbol number.
    predicted: [bool; S],
    /// Medial items of the set.
    medial: Vec<Item>,
}
162
163#[derive(Ord, PartialOrd, Eq, PartialEq, Clone)]
164struct Item {
165    postdot: Symbol,
166    dot: usize,
167    origin: usize,
168    node: NodeHandle,
169}
170
171#[derive(Clone, Copy, Ord, PartialOrd, Eq, PartialEq)]
172struct CompletedItem {
173    origin: usize,
174    dot: usize,
175    left_node: NodeHandle,
176    right_node: Option<NodeHandle>,
177}
178
/// In-place element-wise OR of boolean rows.
trait UnionWith {
    /// Sets every flag of `self` whose counterpart in `other` is set.
    ///
    /// Only the common prefix is touched when the lengths differ.
    fn union_with(&mut self, other: &[bool]);
}

impl<const S: usize> UnionWith for [bool; S] {
    fn union_with(&mut self, other: &[bool]) {
        self.iter_mut()
            .zip(other)
            .for_each(|(flag, &incoming)| *flag |= incoming);
    }
}

impl<'a> UnionWith for &'a mut [bool] {
    fn union_with(&mut self, other: &[bool]) {
        self.iter_mut()
            .zip(other)
            .for_each(|(flag, &incoming)| *flag |= incoming);
    }
}
198
199impl Symbol {
200    fn usize(self) -> usize {
201        self.0 as usize
202    }
203}
204
205impl<const S: usize> EarleySet<S> {
206    fn new() -> Self {
207        EarleySet {
208            predicted: [false; S],
209            medial: vec![],
210        }
211    }
212}
213
214impl<const S: usize> Grammar<S> {
215    pub fn new(symbol_names: [&'static str; S], start_symbol: usize) -> Self {
216        Self {
217            rules: vec![],
218            start_symbol: Symbol(start_symbol as u32),
219            symbol_names,
220            gen_symbols: 0,
221            rule_id: 2,
222            lhs: Symbol(0),
223            nulling: BTreeMap::new(),
224            rhs_nulling: vec![],
225        }
226    }
227
228    pub fn symbols(&self) -> [Symbol; S] {
229        let mut result = [Symbol(0); S];
230        for (i, elem) in result.iter_mut().enumerate() {
231            *elem = Symbol(i as u32);
232        }
233        result
234    }
235
236    pub fn rule<const N: usize>(&mut self, lhs: Symbol, rhs: [Symbol; N]) -> &mut Self {
237        if N == 0 {
238            self.nulling.insert(lhs, self.rule_id);
239            self.rule_id += 1;
240            self.lhs = lhs;
241            return self;
242        }
243        let mut cur_rhs0 = rhs[0];
244        for i in 1 .. N - 1 {
245            let gensym = Symbol(self.gen_symbols + S as u32);
246            self.gen_symbols += 1;
247            self.rules.push(Rule {
248                lhs: gensym,
249                rhs0: cur_rhs0,
250                rhs1: Some(rhs[i]),
251                id: None,
252            });
253            cur_rhs0 = gensym;
254        }
255        self.rules.push(Rule {
256            lhs,
257            rhs0: cur_rhs0,
258            rhs1: if N == 1 { None } else { Some(rhs[N - 1]) },
259            id: Some(self.rule_id),
260        });
261        self.rule_id += 1;
262        self.lhs = lhs;
263        self
264    }
265
266    fn rhs<const N: usize>(&mut self, rhs: [Symbol; N]) -> &mut Self {
267        self.rule(self.lhs, rhs);
268        self
269    }
270
271    pub fn sort_rules(&mut self) {
272        self.rules.sort_by(|a, b| a.lhs.cmp(&b.lhs));
273    }
274
275    fn eliminate_nulling(&mut self) {
276        let mut new_rules = vec![];
277        let mut rules = self.rules.clone();
278        let mut change = true;
279        while change {
280            change = false;
281            for rule in &self.rules {
282                if let Some(_id) = self.nulling.get(&rule.rhs0) {
283                    if let Some(rhs1) = rule.rhs1 {
284                        new_rules.push(Rule {
285                            lhs: rule.lhs,
286                            rhs0: rhs1,
287                            rhs1: None,
288                            id: rule.id,
289                        });
290                        if let Some(_id) = self.nulling.get(&rhs1) {
291                            self.rhs_nulling.push(NullingEliminated {
292                                rule_id: rule.id,
293                                side: Side::Both,
294                                sym: rule.rhs0,
295                            });
296                            change |= self
297                                .nulling
298                                .insert(rule.lhs, rule.id.unwrap_or(0))
299                                .is_none();
300                        } else {
301                            self.rhs_nulling.push(NullingEliminated {
302                                rule_id: rule.id,
303                                side: Side::Left,
304                                sym: rule.rhs0,
305                            });
306                        }
307                    } else {
308                        // TODO rule.id
309                        self.rhs_nulling.push(NullingEliminated {
310                            rule_id: rule.id,
311                            side: Side::Left,
312                            sym: rule.rhs0,
313                        });
314                        change |= self
315                            .nulling
316                            .insert(rule.lhs, rule.id.unwrap_or(0))
317                            .is_none();
318                    }
319                }
320                if let Some(rhs1) = rule.rhs1 {
321                    if let Some(_id) = self.nulling.get(&rhs1) {
322                        self.rhs_nulling.push(NullingEliminated {
323                            rule_id: rule.id,
324                            side: Side::Right,
325                            sym: rhs1,
326                        });
327                        new_rules.push(Rule {
328                            lhs: rule.lhs,
329                            rhs0: rule.rhs0,
330                            rhs1: None,
331                            id: rule.id,
332                        });
333                    }
334                }
335            }
336        }
337        self.rules.extend(new_rules);
338    }
339
340    pub fn stringify_to_bnf(&self) -> String {
341        use std::fmt::Write;
342        let mut result = String::new();
343        for (i, rule) in self.rules.iter().enumerate() {
344            let tostr = |sym: Symbol| if sym.usize() >= S { format!("g{}({})", sym.usize() - S, sym.usize()) } else { format!("{}({})", self.symbol_names[sym.usize()], sym.usize()) };
345            let lhs = tostr(rule.lhs);
346            let rhs0 = tostr(rule.rhs0);
347            let rhs1 = if let Some(rhs1) = rule.rhs1 {
348                format!(" {}", tostr(rhs1))
349            } else {
350                "".to_string()
351            };
352            writeln!(&mut result, "{}: {} ::= {}{};", i, lhs, rhs0, rhs1).unwrap();
353        }
354        result
355    }
356}
357
358// Implementation for the recognizer.
359//
// The recognizer keeps a chart of Earley sets (`EarleyChart`); the most recent set, opened by `next_set`, is the one currently being filled.
361//
362// A typical loop that utilizes the recognizer:
363//
364// - for character in string {
365// 1.   recognizer.begin_earleme();
366// 2.   recognizer.scan(token_to_symbol(character), values());
367//        2a. complete
368// 3.   recognizer.end_earleme();
369//        3a. self.complete_all_sums_entirely();
370//        3b. self.sort_medial_items();
371//        3c. self.prediction_pass();
372// - }
373//
374impl<const S: usize> Recognizer<S> {
375    pub fn new(grammar: &Grammar<S>) -> Self {
376        let mut result = Self {
377            tables: Tables::new(grammar),
378            earley_chart: EarleyChart::new(),
379            forest: Forest::new(grammar),
380            // complete: BinaryHeap::new_by_key(Box::new(|completed_item| (completed_item.origin, completed_item.dot))),
381            complete: BinaryHeap::with_capacity(64),
382            finished_node: None,
383        };
384        result.initialize();
385        result
386    }
387
388    fn initialize(&mut self) {
389        self.earley_chart.next_set(Some(
390            self.tables.prediction_matrix[self.tables.start_symbol.usize()],
391        ));
392        self.earley_chart.next_set(None);
393    }
394
395    pub fn scan(&mut self, terminal: Symbol, values: u32) {
396        let earleme = self.earley_chart.len() - 2;
397        let node = self.forest.leaf(terminal, earleme + 1, values);
398        self.complete(earleme, terminal, node);
399    }
400
401    pub fn end_earleme(&mut self) -> bool {
402        if self.is_exhausted() {
403            false
404        } else {
405            // Completion pass, which saves successful parses.
406            self.finished_node = None;
407            self.complete_all_sums_entirely();
408            // Do the rest.
409            self.sort_medial_items();
410            self.prediction_pass();
411            self.earley_chart.next_set(None);
412            true
413        }
414    }
415
416    pub fn is_exhausted(&self) -> bool {
417        self.earley_chart.next_medial().len() == 0 && self.complete.is_empty()
418    }
419
420    fn complete_all_sums_entirely(&mut self) {
421        while let Some(&ei) = self.complete.peek() {
422            let lhs_sym = self.tables.get_lhs(ei.dot);
423            let mut result_node = None;
424            while let Some(&ei2) = self.complete.peek() {
425                if ei.origin == ei2.origin && lhs_sym == self.tables.get_lhs(ei2.dot) {
426                    result_node = Some(self.forest.push_summand(ei2));
427                    self.complete.pop();
428                } else {
429                    break;
430                }
431            }
432            if ei.origin == 0 && lhs_sym == self.tables.start_symbol {
433                self.finished_node = Some(result_node.unwrap());
434            }
435            self.complete(ei.origin, lhs_sym, result_node.unwrap());
436        }
437    }
438
439    /// Sorts medial items with deduplication.
440    fn sort_medial_items(&mut self) {
441        // Build index by postdot
442        // These medial positions themselves are sorted by postdot symbol.
443        self.earley_chart.last_mut().items.sort_unstable();
444    }
445
446    fn prediction_pass(&mut self) {
447        // Iterate through medial items in the current set.
448        let mut last = self.earley_chart.last_mut();
449        let iter = last.items.iter();
450        // For each medial item in the current set, predict its postdot symbol.
451        for ei in iter {
452            if let Some(postdot) = self
453                .tables
454                .get_rhs1(ei.dot)
455                .filter(|postdot| !last.predicted[postdot.usize()])
456            {
457                // Prediction happens here. We would prefer to call `self.predict`, but we can't,
458                // because `self.medial` is borrowed by `iter`.
459                let source = &self.tables.prediction_matrix[postdot.usize()][..];
460                last.predicted.union_with(source);
461            }
462        }
463    }
464
465    fn complete(&mut self, earleme: usize, symbol: Symbol, node: NodeHandle) {
466        if symbol.usize() >= S {
467            self.complete_binary_predictions(earleme, symbol, node);
468        } else if self.earley_chart.predicted(earleme)[symbol.usize()] {
469            self.complete_medial_items(earleme, symbol, node);
470            self.complete_predictions(earleme, symbol, node);
471        }
472    }
473
474    fn complete_medial_items(&mut self, earleme: usize, symbol: Symbol, right_node: NodeHandle) {
475        let inner_start = {
476            // we use binary search to narrow down the range of items.
477            let set_idx = self.earley_chart[earleme]
478                .binary_search_by(|ei| (self.tables.get_rhs1(ei.dot), 1).cmp(&(Some(symbol), 0)));
479            match set_idx {
480                Ok(idx) | Err(idx) => idx,
481            }
482        };
483
484        let rhs1_eq = |ei: &&Item| self.tables.get_rhs1(ei.dot) == Some(symbol);
485        for item in self.earley_chart[earleme][inner_start..]
486            .iter()
487            .take_while(rhs1_eq)
488        {
489            self.complete.push(CompletedItem {
490                dot: item.dot,
491                origin: item.origin,
492                left_node: item.node,
493                right_node: Some(right_node),
494            });
495        }
496    }
497
498    fn complete_predictions(&mut self, earleme: usize, symbol: Symbol, node: NodeHandle) {
499        // println!("{:?}", slice);
500        for trans in &self.tables.completions[symbol.usize()] {
501            if self.earley_chart.predicted(earleme)[trans.top.usize()] {
502                if trans.is_unary {
503                    self.complete.push(CompletedItem {
504                        origin: earleme,
505                        dot: trans.dot,
506                        left_node: node,
507                        right_node: None,
508                    });
509                } else {
510                    self.earley_chart.medial().push(Item {
511                        origin: earleme,
512                        dot: trans.dot,
513                        node: node,
514                        postdot: self.tables.get_rhs1(trans.dot).unwrap(),
515                    });
516                }
517            }
518        }
519    }
520
521    fn complete_binary_predictions(&mut self, earleme: usize, symbol: Symbol, node: NodeHandle) {
522        let trans = self.tables.gen_completions[symbol.usize() - S];
523        if self.earley_chart.predicted(earleme)[trans.top.usize()] {
524            self.earley_chart.medial().push(Item {
525                origin: earleme,
526                dot: trans.dot,
527                node,
528                postdot: self.tables.get_rhs1(trans.dot).unwrap(),
529            });
530            if trans.is_unary {
531                self.complete.push(CompletedItem {
532                    origin: earleme,
533                    dot: trans.dot,
534                    left_node: node,
535                    right_node: None,
536                });
537            }
538        }
539    }
540
541    #[cfg(feature = "debug")]
542    pub fn log_last_earley_set(&self) {
543        let dots = self.dots_for_log(self.earley_chart.last());
544        for (rule_id, dots) in dots {
545            print!(
546                "{} ::= ",
547                self.tables.symbol_names[self.tables.get_lhs(rule_id).usize()]
548            );
549            if let Some(origins) = dots.get(&0) {
550                print!("{:?}", origins);
551            }
552            print!(
553                " {} ",
554                self.tables.symbol_names[self.tables.rules[rule_id].rhs0.usize()]
555            );
556            if let Some(origins) = dots.get(&1) {
557                print!("{:?}", origins);
558            }
559            if let Some(rhs1) = self.tables.get_rhs1(rule_id) {
560                print!(" {} ", self.tables.symbol_names[rhs1.usize()]);
561            }
562            println!();
563        }
564        println!();
565    }
566
567    pub fn earley_set_diff(&self) -> String {
568        let mut result = String::new();
569        use std::fmt::Write;
570        use std::collections::{BTreeMap, BTreeSet};
571        let dots_last_by_id = self.dots_for_log(self.earley_chart.next_to_last());
572        let mut dots_next_by_id = self.dots_for_log(self.earley_chart.last());
573        let mut rule_ids: BTreeSet<usize> = BTreeSet::new();
574        rule_ids.extend(dots_last_by_id.keys());
575        rule_ids.extend(dots_next_by_id.keys());
576        for item in self.complete.iter() {
577            let position = if self.tables.get_rhs1(item.dot).is_some() {
578                2
579            } else {
580                1
581            };
582            dots_next_by_id
583                .entry(item.dot)
584                .or_insert(BTreeMap::new())
585                .entry(position)
586                .or_insert(BTreeSet::new())
587                .insert(item.origin);
588        }
589        let mut joined: BTreeMap<usize, BTreeMap<usize, (BTreeSet<usize>, BTreeSet<_>)>> = BTreeMap::new();
590        for (rule_id, map) in dots_last_by_id {
591            for (pos, set) in map {
592                joined.entry(rule_id).or_insert(BTreeMap::new()).entry(pos).or_insert((BTreeSet::new(), BTreeSet::new())).0.extend(set);
593            }
594        }
595        for (rule_id, map) in dots_next_by_id {
596            for (pos, set) in map {
597                joined.entry(rule_id).or_insert(BTreeMap::new()).entry(pos).or_insert((BTreeSet::new(), BTreeSet::new())).1.extend(set);
598            }
599        }
600        let mut empty_diff = true;
601        for (rule_id, dots) in joined {
602            // let dots_last = dots_last_by_id.get(&rule_id).unwrap_or(BTreeMap::new());
603            // let dots_next = dots_next_by_id.get(&rule_id);
604            // if dots_last == dots_next {
605            //     continue;
606            // }
607            empty_diff = false;
608            write!(
609                result,
610                "diff {} ::= ",
611                self.tables.symbol_names[self.tables.get_top_lhs(rule_id).usize()]
612            );
613            if let Some(&(ref a, ref b)) = dots.get(&0) {
614                write!(result, "{:?}=>{:?}", a, b);
615            }
616            write!(
617                result,
618                " {} ",
619                self.tables.symbol_names[self.tables.get_top_rhs0(rule_id).usize()]
620            );
621            if let Some(&(ref a, ref b)) = dots.get(&1) {
622                write!(result, "{:?}=>{:?}", a, b);
623            }
624            if let Some(rhs1) = self.tables.get_rhs1(rule_id) {
625                write!(result, " {} ", self.tables.symbol_names[rhs1.usize()]);
626            }
627            if let Some(&(ref a, ref b)) = dots.get(&2) {
628                write!(result, "{:?}=>{:?}", a, b);
629            }
630            writeln!(result, "");
631        }
632        if empty_diff {
633            writeln!(result, "no diff");
634            writeln!(result, "");
635        } else {
636            writeln!(result, "");
637        }
638        result
639    }
640
641    #[cfg(feature = "debug")]
642    fn dots_for_log(&self, es: EarleySetRef) -> BTreeMap<usize, BTreeMap<usize, BTreeSet<usize>>> {
643        let mut dots = BTreeMap::new();
644        for (i, rule) in self.tables.rules.iter().enumerate() {
645            if es.predicted[self.tables.get_top_lhs(i).usize()] {
646                dots.entry(i)
647                    .or_insert(BTreeMap::new())
648                    .entry(0)
649                    .or_insert(BTreeSet::new())
650                    .insert(self.earley_chart.len() - 1);
651            }
652        }
653        for item in es.items {
654            dots.entry(item.dot)
655                .or_insert(BTreeMap::new())
656                .entry(1)
657                .or_insert(BTreeSet::new())
658                .insert(item.origin);
659        }
660        dots
661    }
662}
663
664impl<const S: usize> Tables<S> {
665    fn new(grammar: &Grammar<S>) -> Self {
666        let mut result = Self {
667            prediction_matrix: [[false; S]; S],
668            start_symbol: grammar.start_symbol,
669            rules: vec![],
670            rules_by_rhs0: vec![],
671            completions: vec![],
672            symbol_names: grammar.symbol_names,
673            gen_completions: vec![Default::default(); grammar.gen_symbols as usize],
674        };
675        result.populate(grammar);
676        result
677    }
678
679    fn populate(&mut self, grammar: &Grammar<S>) {
680        self.populate_rules(grammar);
681        self.populate_prediction_matrix(grammar);
682        self.populate_completions(grammar);
683    }
684
685    fn populate_prediction_matrix(&mut self, grammar: &Grammar<S>) {
686        for rule in &grammar.rules {
687            if rule.rhs0.usize() < S {
688                let mut top = rule.lhs;
689                while top.usize() >= S {
690                    // appears on only one rhs0
691                    let idx = self
692                        .rules_by_rhs0
693                        .binary_search_by_key(&top, |elem| elem.rhs0)
694                        .expect("lhs not found");
695                    top = self.rules_by_rhs0[idx].lhs;
696                }
697                self.prediction_matrix[top.usize()][rule.rhs0.usize()] = true;
698            }
699        }
700        self.reflexive_closure();
701        self.transitive_closure();
702    }
703
704    fn reflexive_closure(&mut self) {
705        for i in 0..S {
706            self.prediction_matrix[i][i] = true;
707        }
708    }
709
710    fn transitive_closure(&mut self) {
711        for pos in 0..S {
712            let (rows0, rows1) = self.prediction_matrix.split_at_mut(pos);
713            let (rows1, rows2) = rows1.split_at_mut(1);
714            for dst_row in rows0.iter_mut().chain(rows2.iter_mut()) {
715                if dst_row[pos] {
716                    dst_row.union_with(&rows1[0]);
717                }
718            }
719        }
720    }
721
722    fn populate_rules(&mut self, grammar: &Grammar<S>) {
723        self.rules = grammar.rules.clone();
724        self.rules_by_rhs0 = self.rules.clone();
725        self.rules_by_rhs0.sort_by_key(|rule| rule.rhs0);
726    }
727
728    fn populate_completions(&mut self, grammar: &Grammar<S>) {
729        self.completions
730            .resize(S, vec![]);
731        for (i, rule) in grammar.rules.iter().enumerate() {
732            let rhs0 = rule.rhs0.usize();
733            let mut top = rule.lhs;
734            while top.usize() >= S {
735                // appears on only one rhs0
736                let idx = self
737                    .rules_by_rhs0
738                    .binary_search_by_key(&top, |elem| elem.rhs0)
739                    .expect("lhs not found");
740                top = self.rules_by_rhs0[idx].lhs;
741            }
742            let transition = PredictionTransition {
743                symbol: rule.lhs,
744                top,
745                dot: i,
746                is_unary: rule.rhs1.is_none(),
747            };
748            if rhs0 >= S {
749                if rule.rhs1.is_some() {
750                    self.gen_completions[rhs0 - S] = transition;
751                } else {
752                    self.gen_completions[rhs0 - S].is_unary = true;
753                }
754            } else {
755                self.completions[rhs0].push(transition);
756            }
757        }
758    }
759
760    fn get_rhs1(&self, n: usize) -> Option<Symbol> {
761        self.rules.get(n).and_then(|rule| rule.rhs1)
762    }
763
764    fn get_lhs(&self, n: usize) -> Symbol {
765        self.rules[n].lhs
766    }
767
768    #[cfg(feature = "debug")]
769    fn get_top_lhs(&self, dot: usize) -> Symbol {
770        let mut top = self.rules[dot].lhs;
771        while top.usize() >= S {
772            // appears on only one rhs0
773            let idx = self
774                .rules_by_rhs0
775                .binary_search_by_key(&top, |elem| elem.rhs0)
776                .expect("lhs not found");
777            top = self.rules_by_rhs0[idx].lhs;
778        }
779        top
780    }
781
782    #[cfg(feature = "debug")]
783    fn get_top_rhs0(&self, dot: usize) -> Symbol {
784        let mut top = self.rules[dot].rhs0;
785        while top.usize() >= S {
786            // appears on only one rhs0
787            let idx = self
788                .rules_by_rhs0
789                .binary_search_by_key(&top, |elem| elem.rhs0)
790                .expect("lhs not found");
791            top = self.rules_by_rhs0[idx].lhs;
792        }
793        top
794    }
795}
796
/// Value produced while evaluating a calculator parse.
#[derive(Debug, Clone)]
pub enum Value {
    /// Digit characters (optionally with a decimal point) not yet parsed.
    Digits(String),
    /// A computed number.
    Float(f64),
    /// No value, e.g. for operator and parenthesis tokens.
    None,
}
803
/// The calculator grammar together with a recognizer built from it.
#[derive(Clone)]
pub struct CalcRecognizer {
    /// Grammar over the 13 calculator symbols.
    grammar: Grammar<13>,
    /// Recognizer created from `grammar`.
    recognizer: Recognizer<13>,
}
809
810pub fn calc_recognizer() -> CalcRecognizer {
811    let mut grammar = Grammar::new(
812        [
813            "sum", "factor", "op_mul", "op_div", "lparen", "rparen", "expr_sym", "op_minus",
814            "op_plus", "number", "whole", "digit", "dot",
815        ],
816        0,
817    );
818    let [sum, factor, op_mul, op_div, lparen, rparen, expr_sym, op_minus, op_plus, number, whole, digit, dot] =
819        grammar.symbols();
820    // sum ::= sum [+-] factor
821    // sum ::= factor
822    // factor ::= factor [*/] expr
823    // factor ::= expr
824    // expr ::= '(' sum ')' | '-' expr | number
825    // number ::= whole | whole '.' whole
826    // whole ::= whole [0-9] | [0-9]
827    grammar.rule(sum, [sum, op_plus, factor]);
828    grammar.rule(sum, [sum, op_minus, factor]);
829    grammar.rule(sum, [factor]);
830    grammar.rule(factor, [factor, op_mul, expr_sym]);
831    grammar.rule(factor, [factor, op_div, expr_sym]);
832    grammar.rule(factor, [expr_sym]);
833
834    grammar.rule(expr_sym, [lparen, sum, rparen]);
835    grammar.rule(expr_sym, [op_minus, expr_sym]);
836    grammar.rule(expr_sym, [number]);
837    grammar.rule(number, [whole]);
838    grammar.rule(number, [whole, dot, whole]);
839    grammar.rule(whole, [whole, digit]);
840    grammar.rule(whole, [digit]);
841    grammar.sort_rules();
842    let recognizer = Recognizer::new(&grammar);
843    CalcRecognizer {
844        recognizer,
845        grammar,
846    }
847}
848
/// Evaluator for the calculator grammar.
struct E {
    /// The grammar's symbols, in declaration order.
    symbols: [Symbol; 13],
}
852
853impl self::forest::Eval for E {
854    type Elem = Value;
855
856    fn leaf(&self, terminal: Symbol, values: u32) -> Self::Elem {
857        let [sum, factor, op_mul, op_div, lparen, rparen, _expr_sym, op_minus, op_plus, _number, _whole, digit, dot] =
858            self.symbols;
859        if terminal == digit {
860            Value::Digits((values as u8 as char).to_string())
861        } else {
862            Value::None
863        }
864    }
865
866    fn product(&self, action: u32, args: Vec<Self::Elem>) -> Self::Elem {
867        let [sum, factor, op_mul, op_div, lparen, rparen, _expr_sym, op_minus, op_plus, _number, _whole, digit, dot] =
868            self.symbols;
869        // let mut iter = args.into_iter();
870        match (
871            action,
872            args.get(0).cloned().unwrap_or(Value::None),
873            args.get(1).cloned().unwrap_or(Value::None),
874            args.get(2).cloned().unwrap_or(Value::None),
875        ) {
876            (2, Value::Float(left), _, Value::Float(right)) => Value::Float(left + right),
877            (3, Value::Float(left), _, Value::Float(right)) => Value::Float(left - right),
878            (4, val, Value::None, Value::None) => val,
879            (5, Value::Float(left), _, Value::Float(right)) => Value::Float(left * right),
880            (6, Value::Float(left), _, Value::Float(right)) => Value::Float(left / right),
881            (7, val, Value::None, Value::None) => val,
882            (8, _, val, _) => val,
883            (9, _, Value::Float(num), Value::None) => Value::Float(-num),
884            (10, Value::Digits(digits), Value::None, Value::None) => {
885                Value::Float(digits.parse::<f64>().unwrap())
886            }
887            (11, val @ Value::Digits(..), _, _) => val,
888            (12, Value::Digits(before_dot), _, Value::Digits(after_dot)) => {
889                let mut digits = before_dot;
890                digits.push('.');
891                digits.push_str(&after_dot[..]);
892                Value::Digits(digits)
893            }
894            (13, Value::Digits(mut num), Value::Digits(digit), _) => {
895                num.push_str(&digit[..]);
896                Value::Digits(num)
897            }
898            (14, val @ Value::Digits(..), _, _) => val,
899            args => panic!("unknown rule id {:?} or args {:?}", action, args),
900        }
901    }
902}
903
904impl CalcRecognizer {
905    pub fn parse(&mut self, expr: &str) -> f64 {
906        let [sum, factor, op_mul, op_div, lparen, rparen, _expr_sym, op_minus, op_plus, _number, _whole, digit, dot] =
907            self.grammar.symbols();
908
909        let symbols = self.grammar.symbols();
910
911        for (i, ch) in expr.chars().enumerate() {
912            let terminal = match ch {
913                '-' => op_minus,
914                '.' => dot,
915                '0'..='9' => digit,
916                '(' => lparen,
917                ')' => rparen,
918                '*' => op_mul,
919                '/' => op_div,
920                '+' => op_plus,
921                ' ' => continue,
922                other => panic!("invalid character {}", other),
923            };
924            self.recognizer.scan(terminal, ch as u32);
925            let success = self.recognizer.end_earleme();
926            #[cfg(feature = "debug")]
927            if !success {
928                println!("{}", self.recognizer.earley_set_diff());
929            }
930            assert!(success, "parse failed at character {}", i);
931        }
932        let finished_node = self.recognizer.finished_node.expect("parse failed");
933        let result = self
934            .recognizer
935            .forest
936            .evaluator(E { symbols })
937            .evaluate(finished_node);
938        if let Value::Float(num) = result {
939            num
940        } else {
941            panic!("evaluation failed {:?}", result)
942        }
943    }
944}
945
946pub fn calc(expr: &str) -> f64 {
947    let mut recognizer = calc_recognizer();
948    recognizer.parse(expr)
949}
950
#[cfg(test)]
mod test {
    use super::calc;

    /// A single addition of two decimals.
    #[test]
    fn test_parse() {
        assert_eq!(calc("1.0 + 2.0"), 3.0);
    }

    /// A long expression with nesting and all four operators.
    ///
    /// The result is an accumulated floating-point sum, so it is compared with
    /// a small tolerance rather than for exact equality.
    #[test]
    fn test_parse_big() {
        let result = calc("1.0 + (2.0 * 3.0 + (1.0 + 2.0 * 3.0) + 1.0) + 2.0 * 3.0 / 1.0 + 2.0 \
            * 3.01234234 + (2.0 * 3.0 + (1.0 + 2.0 * 3.0) + 1.0) + 2.0 * 3.0 / 1.0 + 2.0 * 3.01234234 + \
            (2.0 * 3.0 + (1.0 + 2.0 * 3.0) + 1.0) + 2.0 * 3.0 / 1.0 + 2.0 * 3.01234234");
        assert!(
            (result - 79.07405404).abs() < 1e-9,
            "expected 79.07405404, got {}",
            result
        );
    }
}
967
#[cfg(all(feature = "nightly", test))]
mod bench {
    extern crate test;
    use super::calc_recognizer;
    use test::Bencher;

    /// Benchmarks parsing `input` on a fresh clone of a prebuilt recognizer.
    ///
    /// The byte count reported to the harness is taken from the input itself,
    /// so it cannot drift from the literal.
    fn bench_input(bench: &mut Bencher, input: &str) {
        let recognizer = calc_recognizer();
        bench.bytes = input.len() as u64;
        bench.iter(|| {
            let mut parser = recognizer.clone();
            parser.parse(input)
        });
    }

    #[bench]
    fn bench_parser(bench: &mut Bencher) {
        bench_input(bench, "1.0 + 2.0");
    }

    #[bench]
    fn bench_parser2(bench: &mut Bencher) {
        bench_input(
            bench,
            "1.0 + 2.0 * 3.0 + 1.0 + 2.0 * 3.0 + 1.0 + 2.0 * 3.0 / 1.0 + 2.0 * 3.01234234",
        );
    }

    #[bench]
    fn bench_parser3(bench: &mut Bencher) {
        bench_input(
            bench,
            "1.0 + (2.0 * 3.0 + (1.0 + 2.0 * 3.0) + 1.0) + 2.0 * 3.0 / 1.0 + 2.0 \
            * 3.01234234 + (2.0 * 3.0 + (1.0 + 2.0 * 3.0) + 1.0) + 2.0 * 3.0 / 1.0 + 2.0 * 3.01234234 + \
            (2.0 * 3.0 + (1.0 + 2.0 * 3.0) + 1.0) + 2.0 * 3.0 / 1.0 + 2.0 * 3.01234234",
        );
    }
}
1006
1007#[cfg(all(feature = "nightly", test))]
1008mod bench_c;