1use super::*;
2use nom::branch::alt;
3use nom::bytes::complete::tag;
4use nom::character::complete::{alpha1, alphanumeric1, digit1, one_of};
5use nom::combinator::{all_consuming, map, map_res, recognize};
6use nom::multi::{fold, many0, separated_list0};
7use nom::sequence::{delimited, pair, preceded, separated_pair};
8use nom::{IResult, Parser};
9use nom_language::error::VerboseError;
10
/// Result type shared by every parser in this module: the input is a `&str` slice, `O` is
/// the parsed value, and errors are accumulated in nom's `VerboseError`.
type R<'i, O> = IResult<&'i str, O, VerboseError<&'i str>>;
12
13pub fn parse_tdim(symbol_table: &SymbolScope, input: &str) -> TractResult<TDim> {
14 match all_consuming(|i| expr(symbol_table, i)).parse(input) {
15 Ok(pair) => Ok(pair.1),
16 Err(e) => bail!("Failed to parse {:?}, {:?}", input, e),
17 }
18}
19
20pub fn parse_assertion(symbol_table: &SymbolScope, input: &str) -> TractResult<Assertion> {
21 match all_consuming(|i| assertion(symbol_table, i)).parse(input) {
22 Ok(pair) => Ok(pair.1),
23 Err(e) => bail!("Failed to parse {:?}, {:?}", input, e),
24 }
25}
26
27fn assertion<'i>(s: &SymbolScope, i: &'i str) -> R<'i, Assertion> {
28 delimited(
29 spaces,
30 alt((
31 map(separated_pair(|i| expr(s, i), stag("=="), |i| expr(s, i)), |(a, b)| {
32 Assertion::Eq(a, b)
33 }),
34 map(separated_pair(|i| expr(s, i), stag("<="), |i| expr(s, i)), |(a, b)| {
35 Assertion::LTE(a, b)
36 }),
37 map(separated_pair(|i| expr(s, i), stag(">="), |i| expr(s, i)), |(a, b)| {
38 Assertion::GTE(a, b)
39 }),
40 map(separated_pair(|i| expr(s, i), stag("<"), |i| expr(s, i)), |(a, b)| {
41 Assertion::LT(a, b)
42 }),
43 map(separated_pair(|i| expr(s, i), stag(">"), |i| expr(s, i)), |(a, b)| {
44 Assertion::GT(a, b)
45 }),
46 )),
47 spaces,
48 )
49 .parse(i)
50}
51
52fn expr<'i>(symbol_table: &SymbolScope, i: &'i str) -> R<'i, TDim> {
53 broadcast(symbol_table, i)
54}
55
56fn broadcast<'i>(symbol_table: &SymbolScope, input: &'i str) -> R<'i, TDim> {
57 let s = symbol_table;
58 let (mut input, mut result) = add(s, input)?;
59 while let Ok((i, _)) = stag("#").parse(input) {
60 let (i, next) = map_res(|i| add(s, i), |v| result.clone().broadcast(v)).parse(i)?;
61 (input, result) = (i, next);
62 }
63 Ok((input, result))
64}
65
66macro_rules! bin {
67 ($name: ident, $left: expr, $right: expr, $op: expr, $builder: expr) => {
68 fn $name<'i>(symbol_table: &SymbolScope, input: &'i str) -> R<'i, TDim> {
69 let s = symbol_table;
70 let (input, result) = $left(s, input)?;
71 fold(0.., preceded(stag($op), |i| $right(s, i)), move || result.clone(), $builder)
72 .parse(input)
73 }
74 };
75}
76
77bin!(add, sub, sub, "+", |a, b| a + b);
78bin!(sub, mul, mul, "-", |a, b| a - b);
79bin!(mul, div, div, "*", |a, b| a * b);
80bin!(div, atom, |_s, i| numeric(i), "/", |a, b| a / b);
81
82fn atom<'i>(symbol_table: &SymbolScope, i: &'i str) -> R<'i, TDim> {
83 alt((
84 map(numeric, TDim::Val),
85 map(|i| func(symbol_table, "min", i), TDim::Min),
86 map(|i| func(symbol_table, "max", i), TDim::Max),
87 map(|i| func(symbol_table, "broadcast", i), TDim::Broadcast),
88 map(|i| func(symbol_table, "floor", i), |xs| xs[0].clone()),
89 map(|i| identifier(symbol_table, i), TDim::Sym),
90 map(pair(recognize(stag("-")), |i| atom(symbol_table, i)), |(_, dim)| dim * -1),
91 delimited(stag("("), |i| expr(symbol_table, i), stag(")")),
92 ))
93 .parse(i)
94}
95
96fn func<'i>(symbol_table: &SymbolScope, name: &'static str, i: &'i str) -> R<'i, Vec<TDim>> {
97 preceded(
98 stag(name),
99 delimited(stag("("), separated_list0(stag(","), |i| expr(symbol_table, i)), stag(")")),
100 )
101 .parse(i)
102}
103
104fn identifier<'i>(symbol_table: &SymbolScope, i: &'i str) -> R<'i, Symbol> {
105 map(
106 recognize(pair(alt((alpha1, tag("_"))), many0(alt((alphanumeric1, tag("_"), tag(".")))))),
107 |s| symbol_table.sym(s),
108 )
109 .parse(i)
110}
111
112fn numeric(i: &str) -> R<'_, i64> {
113 map_res(digit1, std::str::FromStr::from_str).parse(i)
114}
115
116fn spaces(i: &str) -> R<'_, ()> {
117 map(many0(one_of(" \t\n\r")), |_| ()).parse(i)
118}
119
120fn spaced<'s, O, P>(it: P) -> impl Parser<&'s str, Output = O, Error = VerboseError<&'s str>>
121where
122 P: Parser<&'s str, Output = O, Error = VerboseError<&'s str>>,
123{
124 delimited(spaces, it, spaces)
125}
126
127pub(super) fn stag<'s>(
128 t: &'static str,
129) -> impl Parser<&'s str, Output = &'s str, Error = VerboseError<&'s str>> {
130 spaced(tag(t))
131}
132
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn parse_int() {
        let table = SymbolScope::default();
        assert_eq!(parse_tdim(&table, "12").unwrap(), TDim::Val(12));
        assert_eq!(parse_tdim(&table, "-12").unwrap(), TDim::Val(-12));
    }

    #[test]
    fn parse_sym() {
        let table = SymbolScope::default();
        assert_eq!(parse_tdim(&table, "x").unwrap(), TDim::Sym(table.sym("x")));
        assert_eq!(
            parse_tdim(&table, "-y").unwrap(),
            TDim::MulInt(-1, Box::new(table.sym("y").into()))
        );
    }

    #[test]
    fn parse_bin() {
        let table = SymbolScope::default();
        assert_eq!(parse_tdim(&table, "1+2").unwrap(), 3.into());
        assert_eq!(parse_tdim(&table, "1-2").unwrap(), (-1).into());
        assert_eq!(parse_tdim(&table, "1*2").unwrap(), 2.into());
        assert_eq!(parse_tdim(&table, "1/2").unwrap(), 0.into());
    }

    #[test]
    fn parse_sub_chains() {
        let table = SymbolScope::default();
        // `-` must associate to the left: (10-2)-3, not 10-(2-3).
        assert_eq!(parse_tdim(&table, "10-2-3").unwrap(), 5.into());
        assert_eq!(parse_tdim(&table, "1-2+3").unwrap(), 2.into());
    }

    #[test]
    fn parse_prio() {
        let table = SymbolScope::default();
        assert_eq!(parse_tdim(&table, "1+2*3").unwrap(), 7.into());
        assert_eq!(parse_tdim(&table, "1*2+3").unwrap(), 5.into());
    }

    #[test]
    fn parse_parens_and_inner_spaces() {
        let table = SymbolScope::default();
        assert_eq!(parse_tdim(&table, "(1+2)*3").unwrap(), 9.into());
        assert_eq!(parse_tdim(&table, "1 + 2 * 3").unwrap(), 7.into());
        assert_eq!(
            parse_tdim(&table, "min( P , S )").unwrap(),
            TDim::Min(vec!(table.sym("P").into(), table.sym("S").into()))
        );
    }

    #[test]
    fn parse_min() {
        let table = SymbolScope::default();
        assert_eq!(
            parse_tdim(&table, "min(P,S)").unwrap(),
            TDim::Min(vec!(table.sym("P").into(), table.sym("S").into()))
        );
    }

    #[test]
    fn parse_broadcast_func() {
        let table = SymbolScope::default();
        assert_eq!(
            parse_tdim(&table, "broadcast(P,S)").unwrap(),
            TDim::Broadcast(vec!(table.sym("P").into(), table.sym("S").into()))
        );
    }

    #[test]
    fn parse_broadcast_display_roundtrip() {
        let table = SymbolScope::default();
        let original = TDim::Broadcast(vec![table.sym("P").into(), table.sym("S").into()]);
        let printed = format!("{original}");
        let reparsed = parse_tdim(&table, &printed).unwrap();
        assert_eq!(reparsed, original);
    }

    #[test]
    fn parse_inequality_0() {
        let table = SymbolScope::default();
        assert_eq!(
            parse_assertion(&table, "P+S<4096").unwrap(),
            Assertion::LT(parse_tdim(&table, "P+S").unwrap(), 4096.to_dim())
        );
    }

    #[test]
    fn parse_all_comparisons() {
        let table = SymbolScope::default();
        let p = || parse_tdim(&table, "P").unwrap();
        assert_eq!(parse_assertion(&table, "P==2").unwrap(), Assertion::Eq(p(), 2.to_dim()));
        assert_eq!(parse_assertion(&table, "P<=2").unwrap(), Assertion::LTE(p(), 2.to_dim()));
        assert_eq!(parse_assertion(&table, "P>=2").unwrap(), Assertion::GTE(p(), 2.to_dim()));
        assert_eq!(parse_assertion(&table, "P>2").unwrap(), Assertion::GT(p(), 2.to_dim()));
        assert_eq!(parse_assertion(&table, " P < 2 ").unwrap(), Assertion::LT(p(), 2.to_dim()));
    }

    #[test]
    fn parse_dot_ids() {
        let table = SymbolScope::default();
        assert_eq!(parse_tdim(&table, "dot.0").unwrap(), table.sym("dot.0").into());
    }

    #[test]
    fn parse_dot_ids_arith() {
        let table = SymbolScope::default();
        assert_eq!(parse_tdim(&table, "dot.0/2").unwrap(), table.sym("dot.0").to_dim() / 2);
    }

    #[test]
    fn parse_floors() {
        let table = SymbolScope::default();
        assert_eq!(parse_tdim(&table, "floor(a)").unwrap(), table.sym("a").to_dim());
    }
}