treebender/
rules.rs

1use std::collections::{HashMap, HashSet};
2use std::fmt;
3use std::rc::Rc;
4
5use crate::featurestructure::NodeRef;
6use crate::utils::Err;
7
/// Tags a right-hand-side symbol as either a terminal or a nonterminal.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ProductionKind {
  /// A terminal symbol; `Grammar::new` does not require any rule for it.
  Terminal,
  /// A nonterminal symbol; `Grammar::new` rejects the grammar unless at least
  /// one rule has this symbol as its left-hand side.
  Nonterminal,
}
13
14#[derive(Debug, Clone, PartialEq)]
15pub struct Production {
16  pub kind: ProductionKind,
17  pub symbol: String,
18}
19
20impl Production {
21  pub fn new_terminal(symbol: String) -> Self {
22    Self {
23      kind: ProductionKind::Terminal,
24      symbol,
25    }
26  }
27
28  pub fn new_nonterminal(symbol: String) -> Self {
29    Self {
30      kind: ProductionKind::Nonterminal,
31      symbol,
32    }
33  }
34
35  pub fn is_terminal(&self) -> bool {
36    self.kind == ProductionKind::Terminal
37  }
38
39  pub fn is_nonterminal(&self) -> bool {
40    self.kind == ProductionKind::Nonterminal
41  }
42}
43
44impl fmt::Display for Production {
45  fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
46    write!(f, "{}", self.symbol)
47  }
48}
49
50#[derive(Debug, PartialEq)]
51pub struct Rule {
52  pub symbol: String,
53  pub features: NodeRef,
54  pub productions: Vec<Production>,
55}
56
57impl Rule {
58  pub fn len(&self) -> usize {
59    self.productions.len()
60  }
61
62  pub fn is_empty(&self) -> bool {
63    self.len() == 0
64  }
65}
66
67impl std::fmt::Display for Rule {
68  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69    write!(f, "{}{} ->", self.symbol, self.features)?;
70    for p in self.productions.iter() {
71      write!(f, " {}", p)?;
72    }
73    Ok(())
74  }
75}
76
77#[derive(Debug)]
78pub struct Grammar {
79  pub start: String,
80  pub rules: HashMap<String, Vec<Rc<Rule>>>,
81  nullables: HashSet<String>,
82  nonterminals: HashSet<String>,
83}
84
85impl std::fmt::Display for Grammar {
86  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87    writeln!(f, "//** start: {}", self.start)?;
88    write!(f, "//** nonterminals:")?;
89    for nt in self.nonterminals.iter() {
90      write!(f, " {}", nt)?;
91    }
92    writeln!(f)?;
93
94    write!(f, "//** nullables:")?;
95    for nt in self.nullables.iter() {
96      write!(f, " {}", nt)?;
97    }
98    writeln!(f)?;
99
100    for rule in self.rules.values().flatten() {
101      writeln!(f, "{}\n", rule)?;
102    }
103
104    Ok(())
105  }
106}
107
108impl Grammar {
109  pub fn new(rules: Vec<Rule>) -> Result<Self, Err> {
110    assert!(!rules.is_empty());
111
112    let nonterminals: HashSet<String> = rules.iter().map(|r| r.symbol.clone()).collect();
113    let start = rules[0].symbol.clone();
114
115    for r in rules.iter() {
116      for p in r.productions.iter() {
117        if p.is_nonterminal() && !nonterminals.contains(&p.symbol) {
118          return Err(format!("missing rules for nonterminal {}", p.symbol).into());
119        }
120      }
121    }
122
123    let rules: HashMap<String, Vec<Rc<Rule>>> =
124      rules.into_iter().fold(HashMap::new(), |mut map, rule| {
125        map
126          .entry(rule.symbol.clone())
127          .or_insert_with(Vec::new)
128          .push(Rc::new(rule));
129        map
130      });
131
132    let nullables = Self::find_nullables(&rules);
133
134    Ok(Self {
135      start,
136      rules,
137      nonterminals,
138      nullables,
139    })
140  }
141
142  pub fn is_nullable(&self, s: &str) -> bool {
143    self.nullables.contains(s)
144  }
145}
146
147impl Grammar {
148  fn rule_is_nullable(nullables: &HashSet<String>, rule: &Rule) -> bool {
149    rule.is_empty()
150      || rule
151        .productions
152        .iter()
153        .all(|p| p.is_nonterminal() && nullables.contains(&p.symbol))
154  }
155
156  fn find_nullables(rules: &HashMap<String, Vec<Rc<Rule>>>) -> HashSet<String> {
157    let mut nullables: HashSet<String> = HashSet::new();
158
159    let mut last_length = 1;
160    while last_length != nullables.len() {
161      last_length = nullables.len();
162      for r in rules.values().flatten() {
163        if !nullables.contains(&r.symbol) && Self::rule_is_nullable(&nullables, &r) {
164          nullables.insert(r.symbol.clone());
165        }
166      }
167    }
168
169    nullables
170  }
171}
172
/// Parses a small grammar and checks which symbols are indexed as
/// nonterminals and how many rules each one owns.
#[test]
fn test_parse_grammar() {
  // NOTE(review): the third `S` rule tags `CV[ num: #num ]` while the subject
  // uses `#1` — possibly meant to be `#1`. None of the assertions below
  // depend on it.
  let source = r#"
       S -> N[ case: nom, num: #1 ] IV[ num: #1 ]
       S -> N[ case: nom, pron: #1, num: #2 ] TV[ num: #2 ] N[ case: acc, needs_pron: #1 ]
       S -> N[ case: nom, num: #1 ] CV[ num: #num ] Comp S

       N[ num: sg, pron: she ]     -> mary
       IV[ num: top, tense: past ] -> fell
       TV[ num: top, tense: past ] -> kissed
       CV[ num: top, tense: past ] -> said
       Comp -> that
     "#;
  let g: Grammar = source.parse().unwrap();

  // (left-hand-side symbol, number of rules expected for it)
  let expected_rule_counts = [
    ("S", 3),
    ("N", 1),
    ("IV", 1),
    ("TV", 1),
    ("CV", 1),
    ("Comp", 1),
  ];

  let expected_nonterminals: HashSet<String> = expected_rule_counts
    .iter()
    .map(|&(symbol, _)| symbol.to_string())
    .collect();
  assert_eq!(expected_nonterminals, g.nonterminals);
  assert_eq!(g.rules.len(), 6);

  for &(symbol, count) in expected_rule_counts.iter() {
    assert_eq!(g.rules.get(symbol).unwrap().len(), count);
  }

  // Terminals never get an entry of their own in the rule index.
  for terminal in ["that", "mary"].iter() {
    assert!(g.rules.get(*terminal).is_none());
  }
}
205
/// `D` has an empty rule and `B` expands only to `D`s, so exactly those two
/// symbols are expected in the nullable set.
#[test]
fn test_find_nullables() {
  let source = r#"
      S -> A B
      A -> c
      B -> D D
      D ->
    "#;
  let g: Grammar = source.parse().unwrap();

  let mut expected: HashSet<String> = HashSet::new();
  expected.insert(String::from("B"));
  expected.insert(String::from("D"));

  assert_eq!(g.nullables, expected);
}