1use super::*;
2use nom::branch::alt;
3use nom::bytes::complete::tag;
4use nom::character::complete::{alpha1, alphanumeric1, digit1, one_of};
5use nom::combinator::{all_consuming, map, map_res, recognize};
6use nom::multi::{many0, separated_list0};
7use nom::sequence::{delimited, pair, preceded, separated_pair};
8use nom::IResult;
9
10pub fn parse_tdim(symbol_table: &SymbolScope, input: &str) -> TractResult<TDim> {
11 match all_consuming(|i| expr(symbol_table, i))(input) {
12 Ok(pair) => Ok(pair.1),
13 Err(e) => bail!("Failed to parse {:?}, {:?}", input, e),
14 }
15}
16
17pub fn parse_assertion(symbol_table: &SymbolScope, input: &str) -> TractResult<Assertion> {
18 match all_consuming(|i| assertion(symbol_table, i))(input) {
19 Ok(pair) => Ok(pair.1),
20 Err(e) => bail!("Failed to parse {:?}, {:?}", input, e),
21 }
22}
23
24fn assertion<'i>(s: &SymbolScope, i: &'i str) -> IResult<&'i str, Assertion> {
25 delimited(spaces, alt((
26 map(separated_pair(|i| expr(s, i), stag("=="), |i| expr(s, i)), |(a, b)| {
27 Assertion::Eq(a, b)
28 }),
29 map(separated_pair(|i| expr(s, i), stag("<="), |i| expr(s, i)), |(a, b)| {
30 Assertion::LTE(a, b)
31 }),
32 map(separated_pair(|i| expr(s, i), stag(">="), |i| expr(s, i)), |(a, b)| {
33 Assertion::GTE(a, b)
34 }),
35 map(separated_pair(|i| expr(s, i), stag("<"), |i| expr(s, i)), |(a, b)| {
36 Assertion::LT(a, b)
37 }),
38 map(separated_pair(|i| expr(s, i), stag(">"), |i| expr(s, i)), |(a, b)| {
39 Assertion::GT(a, b)
40 }),
41 )), spaces)(i)
42}
43
44fn expr<'i>(symbol_table: &SymbolScope, i: &'i str) -> IResult<&'i str, TDim> {
45 broadcast(symbol_table, i)
46}
47
48macro_rules! bin {
49 ($name: ident, $next: ident, $op: expr, $builder: expr) => {
50 fn $name<'i>(symbol_table: &SymbolScope, input: &'i str) -> IResult<&'i str, TDim> {
51 let s = symbol_table;
52 alt((map(separated_pair(|i| $next(s, i), stag($op), |i| $next(s, i)), $builder), |i| {
53 $next(s, i)
54 }))(input)
55 }
56 };
57}
58bin!(add, sub, "+", |(a, b)| a + b);
59bin!(sub, mul, "-", |(a, b)| a - b);
60bin!(mul, div, "*", |(a, b)| a * b);
61
62fn broadcast<'i>(symbol_table: &SymbolScope, input: &'i str) -> IResult<&'i str, TDim> {
63 let s = symbol_table;
64 alt((
65 map_res(separated_pair(|i| add(s, i), stag("#"), |i| add(s, i)), |(a, b)| a.broadcast(b)),
66 |i| add(s, i),
67 ))(input)
68}
69
70fn div<'i>(symbol_table: &SymbolScope, input: &'i str) -> IResult<&'i str, TDim> {
71 let s = symbol_table;
72 alt((map(separated_pair(|i| atom(s, i), stag("/"), numeric), |(a, b)| a / b), |i| atom(s, i)))(
73 input,
74 )
75}
76
77fn atom<'i>(symbol_table: &SymbolScope, i: &'i str) -> IResult<&'i str, TDim> {
78 alt((
79 map(numeric, TDim::Val),
80 map(|i| func(symbol_table, "min", i), TDim::Min),
81 map(|i| func(symbol_table, "max", i), TDim::Max),
82 map(|i| identifier(symbol_table, i), TDim::Sym),
83 map(pair(recognize(stag("-")), |i| atom(symbol_table, i)), |(_, dim)| dim * -1),
84 delimited(stag("("), |i| expr(symbol_table, i), stag(")")),
85 ))(i)
86}
87
88fn func<'i>(
89 symbol_table: &SymbolScope,
90 name: &'static str,
91 i: &'i str,
92) -> IResult<&'i str, Vec<TDim>> {
93 preceded(
94 stag(name),
95 delimited(stag("("), separated_list0(stag(","), |i| expr(symbol_table, i)), stag(")")),
96 )(i)
97}
98
99fn identifier<'i>(symbol_table: &SymbolScope, i: &'i str) -> IResult<&'i str, Symbol> {
100 map(recognize(pair(alt((alpha1, tag("_"))), many0(alt((alphanumeric1, tag("_")))))), |s| {
101 symbol_table.sym(s)
102 })(i)
103}
104
105fn numeric(i: &str) -> IResult<&str, i64> {
106 map_res(digit1, std::str::FromStr::from_str)(i)
107}
108
109fn spaces(i: &str) -> IResult<&str, ()> {
110 map(many0(one_of(" \t\n\r")), |_| ())(i)
111}
112
113fn spaced<'s, O, F>(it: F) -> impl FnMut(&'s str) -> IResult<&'s str, O>
114where
115 F: FnMut(&'s str) -> IResult<&'s str, O>,
116{
117 delimited(spaces, it, spaces)
118}
119
120pub(super) fn stag<'s>(t: &'static str) -> impl FnMut(&'s str) -> IResult<&'s str, &'s str> {
121 spaced(tag(t))
122}
123
/// Unit tests for the `TDim` expression and assertion parsers.
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn parse_int() {
        let scope = SymbolScope::default();
        let parse = |text: &str| parse_tdim(&scope, text).unwrap();
        assert_eq!(parse("12"), TDim::Val(12));
        assert_eq!(parse("-12"), TDim::Val(-12));
    }

    #[test]
    fn parse_sym() {
        let scope = SymbolScope::default();
        let parse = |text: &str| parse_tdim(&scope, text).unwrap();
        assert_eq!(parse("x"), TDim::Sym(scope.sym("x")));
        assert_eq!(parse("-y"), TDim::MulInt(-1, Box::new(scope.sym("y").into())));
    }

    #[test]
    fn parse_bin() {
        let scope = SymbolScope::default();
        let parse = |text: &str| parse_tdim(&scope, text).unwrap();
        assert_eq!(parse("1+2"), 3.into());
        assert_eq!(parse("1-2"), (-1).into());
        assert_eq!(parse("1*2"), 2.into());
        assert_eq!(parse("1/2"), 0.into());
    }

    #[test]
    fn parse_prio() {
        let scope = SymbolScope::default();
        let parse = |text: &str| parse_tdim(&scope, text).unwrap();
        assert_eq!(parse("1+2*3"), 7.into());
        assert_eq!(parse("1*2+3"), 5.into());
    }

    #[test]
    fn parse_min() {
        let scope = SymbolScope::default();
        let expected = TDim::Min(vec![scope.sym("P").into(), scope.sym("S").into()]);
        assert_eq!(parse_tdim(&scope, "min(P,S)").unwrap(), expected);
    }

    #[test]
    fn parse_inequality_0() {
        let scope = SymbolScope::default();
        let lhs = parse_tdim(&scope, "P+S").unwrap();
        assert_eq!(parse_assertion(&scope, "P+S<4096").unwrap(), Assertion::LT(lhs, 4096.to_dim()));
    }
}