tract_data/dim/parse.rs

1use super::*;
2use nom::branch::alt;
3use nom::bytes::complete::tag;
4use nom::character::complete::{alpha1, alphanumeric1, digit1, one_of};
5use nom::combinator::{all_consuming, map, map_res, recognize};
6use nom::multi::{many0, separated_list0};
7use nom::sequence::{delimited, pair, preceded, separated_pair};
8use nom::IResult;
9
10pub fn parse_tdim(symbol_table: &SymbolScope, input: &str) -> TractResult<TDim> {
11    match all_consuming(|i| expr(symbol_table, i))(input) {
12        Ok(pair) => Ok(pair.1),
13        Err(e) => bail!("Failed to parse {:?}, {:?}", input, e),
14    }
15}
16
17pub fn parse_assertion(symbol_table: &SymbolScope, input: &str) -> TractResult<Assertion> {
18    match all_consuming(|i| assertion(symbol_table, i))(input) {
19        Ok(pair) => Ok(pair.1),
20        Err(e) => bail!("Failed to parse {:?}, {:?}", input, e),
21    }
22}
23
24fn assertion<'i>(s: &SymbolScope, i: &'i str) -> IResult<&'i str, Assertion> {
25    delimited(spaces, alt((
26        map(separated_pair(|i| expr(s, i), stag("=="), |i| expr(s, i)), |(a, b)| {
27            Assertion::Eq(a, b)
28        }),
29        map(separated_pair(|i| expr(s, i), stag("<="), |i| expr(s, i)), |(a, b)| {
30            Assertion::LTE(a, b)
31        }),
32        map(separated_pair(|i| expr(s, i), stag(">="), |i| expr(s, i)), |(a, b)| {
33            Assertion::GTE(a, b)
34        }),
35        map(separated_pair(|i| expr(s, i), stag("<"), |i| expr(s, i)), |(a, b)| {
36            Assertion::LT(a, b)
37        }),
38        map(separated_pair(|i| expr(s, i), stag(">"), |i| expr(s, i)), |(a, b)| {
39            Assertion::GT(a, b)
40        }),
41    )), spaces)(i)
42}
43
44fn expr<'i>(symbol_table: &SymbolScope, i: &'i str) -> IResult<&'i str, TDim> {
45    broadcast(symbol_table, i)
46}
47
48macro_rules! bin {
49    ($name: ident, $next: ident, $op: expr, $builder: expr) => {
50        fn $name<'i>(symbol_table: &SymbolScope, input: &'i str) -> IResult<&'i str, TDim> {
51            let s = symbol_table;
52            alt((map(separated_pair(|i| $next(s, i), stag($op), |i| $next(s, i)), $builder), |i| {
53                $next(s, i)
54            }))(input)
55        }
56    };
57}
58bin!(add, sub, "+", |(a, b)| a + b);
59bin!(sub, mul, "-", |(a, b)| a - b);
60bin!(mul, div, "*", |(a, b)| a * b);
61
62fn broadcast<'i>(symbol_table: &SymbolScope, input: &'i str) -> IResult<&'i str, TDim> {
63    let s = symbol_table;
64    alt((
65        map_res(separated_pair(|i| add(s, i), stag("#"), |i| add(s, i)), |(a, b)| a.broadcast(b)),
66        |i| add(s, i),
67    ))(input)
68}
69
70fn div<'i>(symbol_table: &SymbolScope, input: &'i str) -> IResult<&'i str, TDim> {
71    let s = symbol_table;
72    alt((map(separated_pair(|i| atom(s, i), stag("/"), numeric), |(a, b)| a / b), |i| atom(s, i)))(
73        input,
74    )
75}
76
77fn atom<'i>(symbol_table: &SymbolScope, i: &'i str) -> IResult<&'i str, TDim> {
78    alt((
79        map(numeric, TDim::Val),
80        map(|i| func(symbol_table, "min", i), TDim::Min),
81        map(|i| func(symbol_table, "max", i), TDim::Max),
82        map(|i| identifier(symbol_table, i), TDim::Sym),
83        map(pair(recognize(stag("-")), |i| atom(symbol_table, i)), |(_, dim)| dim * -1),
84        delimited(stag("("), |i| expr(symbol_table, i), stag(")")),
85    ))(i)
86}
87
88fn func<'i>(
89    symbol_table: &SymbolScope,
90    name: &'static str,
91    i: &'i str,
92) -> IResult<&'i str, Vec<TDim>> {
93    preceded(
94        stag(name),
95        delimited(stag("("), separated_list0(stag(","), |i| expr(symbol_table, i)), stag(")")),
96    )(i)
97}
98
99fn identifier<'i>(symbol_table: &SymbolScope, i: &'i str) -> IResult<&'i str, Symbol> {
100    map(recognize(pair(alt((alpha1, tag("_"))), many0(alt((alphanumeric1, tag("_")))))), |s| {
101        symbol_table.sym(s)
102    })(i)
103}
104
105fn numeric(i: &str) -> IResult<&str, i64> {
106    map_res(digit1, std::str::FromStr::from_str)(i)
107}
108
109fn spaces(i: &str) -> IResult<&str, ()> {
110    map(many0(one_of(" \t\n\r")), |_| ())(i)
111}
112
113fn spaced<'s, O, F>(it: F) -> impl FnMut(&'s str) -> IResult<&'s str, O>
114where
115    F: FnMut(&'s str) -> IResult<&'s str, O>,
116{
117    delimited(spaces, it, spaces)
118}
119
120pub(super) fn stag<'s>(t: &'static str) -> impl FnMut(&'s str) -> IResult<&'s str, &'s str> {
121    spaced(tag(t))
122}
123
#[cfg(test)]
mod test {
    use super::*;

    /// Parses `text` as a dimension, panicking with the parser error on failure.
    fn dim_of(table: &SymbolScope, text: &str) -> TDim {
        parse_tdim(table, text).unwrap()
    }

    #[test]
    fn parse_int() {
        let table = SymbolScope::default();
        assert_eq!(dim_of(&table, "12"), TDim::Val(12));
        assert_eq!(dim_of(&table, "-12"), TDim::Val(-12));
    }

    #[test]
    fn parse_sym() {
        let table = SymbolScope::default();
        let plain = dim_of(&table, "x");
        assert_eq!(plain, TDim::Sym(table.sym("x")));
        let negated = dim_of(&table, "-y");
        assert_eq!(negated, TDim::MulInt(-1, Box::new(table.sym("y").into())));
    }

    #[test]
    fn parse_bin() {
        let table = SymbolScope::default();
        let cases = [("1+2", 3), ("1-2", -1), ("1*2", 2), ("1/2", 0)];
        for (text, expected) in cases {
            assert_eq!(dim_of(&table, text), expected.into());
        }
    }

    #[test]
    fn parse_prio() {
        let table = SymbolScope::default();
        let cases = [("1+2*3", 7), ("1*2+3", 5)];
        for (text, expected) in cases {
            assert_eq!(dim_of(&table, text), expected.into());
        }
    }

    #[test]
    fn parse_min() {
        let table = SymbolScope::default();
        let parsed = dim_of(&table, "min(P,S)");
        let expected = TDim::Min(vec![table.sym("P").into(), table.sym("S").into()]);
        assert_eq!(parsed, expected);
    }

    #[test]
    fn parse_inequality_0() {
        let table = SymbolScope::default();
        let parsed = parse_assertion(&table, "P+S<4096").unwrap();
        let expected = Assertion::LT(dim_of(&table, "P+S"), 4096.to_dim());
        assert_eq!(parsed, expected);
    }
}