tract_nnef/ast/
parse.rs

1use tract_core::internal::*;
2
3use nom::branch::alt;
4use nom::combinator::map;
5use nom::IResult;
6use nom::{bytes::complete::*, character::complete::*, combinator::*, multi::*, sequence::*};
7
8use crate::ast::*;
9
10pub(super) fn translate_error<E: std::fmt::Debug>(e: E) -> TractError {
11    format_err!("Fail to parse NNEF document: {:?}", e)
12}
13
14#[inline(never)]
15pub fn parse_document(doc: &str) -> TractResult<Document> {
16    all_consuming(document)(doc).map(|pair| pair.1).map_err(translate_error)
17}
18
19#[inline(never)]
20pub fn parse_fragments(doc: &str) -> TractResult<Vec<FragmentDef>> {
21    all_consuming(fragments)(doc).map(|pair| pair.1).map_err(translate_error)
22}
23
24#[inline(never)]
25pub fn parse_fragment_decl(doc: &str) -> TractResult<FragmentDecl> {
26    all_consuming(fragment_decl)(doc).map(|pair| pair.1).map_err(translate_error)
27}
28
29#[inline(never)]
30pub fn parse_parameters(doc: &str) -> TractResult<Vec<Parameter>> {
31    all_consuming(parameter_list)(doc).map(|pair| pair.1).map_err(translate_error)
32}
33
34// <document> ::= <version> <extension>* <fragmentdefinition>* <graph-definition>
35fn document(i: &str) -> IResult<&str, Document> {
36    map(
37        tuple((version, many0(extension), fragments, graph_def)),
38        |(version, extension, fragments, graph_def)| Document {
39            version,
40            extension,
41            fragments,
42            graph_def,
43        },
44    )(i)
45}
46
47fn fragments(i: &str) -> IResult<&str, Vec<FragmentDef>> {
48    many0(fragment_def)(i)
49}
50
51// <version> ::= "version" <numeric-literal> ";"
52
53fn version(i: &str) -> IResult<&str, NumericLiteral> {
54    delimited(stag("version"), numeric_literal, stag(";"))(i)
55}
56
57// NNEF spec: <extension> ::= "extension" <identifier>+ ";"
58// tract accepts: <extension> ::= "extension" <identifier> <anything-but-;>";"
59fn extension(i: &str) -> IResult<&str, (Identifier, String)> {
60    delimited(
61        stag("extension"),
62        pair(spaced(identifier), map(take_until(";"), |s: &str| s.to_string())),
63        stag(";"),
64    )(i)
65}
66
67// FRAGMENT
68
69// <fragment-definition> ::= <fragment-declaration> (<body> | ";")
70fn fragment_def(i: &str) -> IResult<&str, FragmentDef> {
71    spaced(map(
72        pair(fragment_decl, alt((map(body, Some), map(stag(";"), |_| None)))),
73        |(decl, body)| FragmentDef { decl, body },
74    ))(i)
75}
76
77// <fragment-declaration> ::= "fragment" <identifier> [<generic-declaration>] "(" <parameter-list> ")" "->" "(" <result-list> ")"
78fn fragment_decl(i: &str) -> IResult<&str, FragmentDecl> {
79    let (i, _) = stag("fragment")(i)?;
80    let (i, id) = identifier(i)?;
81    let (i, generic_decl) = opt(generic_decl)(i)?;
82    let (i, _) = stag("(")(i)?;
83    let (i, parameters) = parameter_list(i)?;
84    let (i, _) = stag(")")(i)?;
85    let (i, _) = stag("->")(i)?;
86    let (i, _) = stag("(")(i)?;
87    let (i, results) = result_list(i)?;
88    let (i, _) = stag(")")(i)?;
89    Ok((i, FragmentDecl { id, parameters, results, generic_decl }))
90}
91
92// <generic-declaration> ::= "<" "?" ["=" <type-name>] ">"
93fn generic_decl(i: &str) -> IResult<&str, Option<TypeName>> {
94    let (i, _) = stag("<")(i)?;
95    let (i, _) = stag("?")(i)?;
96    let (i, name) = opt(preceded(stag("="), type_name))(i)?;
97    let (i, _) = stag(">")(i)?;
98    Ok((i, name))
99}
100
101// <parameter-list> ::= <parameter> ("," <parameter>)*
102fn parameter_list(i: &str) -> IResult<&str, Vec<Parameter>> {
103    separated_list0(stag(","), parameter)(i)
104}
105
106// <result-list> ::= <result> ("," <result>)*
107fn result_list(i: &str) -> IResult<&str, Vec<Result_>> {
108    separated_list0(stag(","), result)(i)
109}
110
111// <parameter> ::= <identifier> ":" <type-spec> ["=" <literal-expr>]
112fn parameter(i: &str) -> IResult<&str, Parameter> {
113    map(
114        pair(
115            separated_pair(identifier, stag(":"), type_spec),
116            opt(preceded(stag("="), literal_expr)),
117        ),
118        |((id, spec), lit)| Parameter { id, spec, lit, doc: None },
119    )(i)
120}
121
122// <result> ::= <identifier> ":" <type-spec>
123fn result(i: &str) -> IResult<&str, Result_> {
124    map(separated_pair(identifier, stag(":"), type_spec), |(id, spec)| Result_ { id, spec })(i)
125}
126
127fn literal_expr(i: &str) -> IResult<&str, Literal> {
128    spaced(alt((
129        literal,
130        map(delimited(stag("["), separated_list0(stag(","), literal), stag("]")), Literal::Array),
131        map(delimited(stag("("), separated_list0(stag(","), literal), stag(")")), Literal::Tuple),
132    )))(i)
133}
134
135// <type-spec> ::= <type-name> | <tensor-type-spec> | <array-type-spec> | <tuple-type-spec>
136fn type_spec(i: &str) -> IResult<&str, TypeSpec> {
137    fn non_array_type(i: &str) -> IResult<&str, TypeSpec> {
138        alt((tuple_type_spec, map(type_name, TypeSpec::Single), tensor_type_spec))(i)
139    }
140    alt((
141        (map(terminated(non_array_type, pair(stag("["), stag("]"))), |t| {
142            TypeSpec::Array(Box::new(t))
143        })),
144        non_array_type,
145    ))(i)
146}
147
// <type-name> ::= "integer" | "scalar" | "logical" | "string" | "?"
// The `complex` feature additionally accepts "complex".
/// Parses a primitive type-name keyword, skipping surrounding spaces and comments.
///
/// Alternatives are tried in the order written and match as plain prefixes; `?` maps to
/// `TypeName::Any`.
fn type_name(i: &str) -> IResult<&str, TypeName> {
    spaced(alt((
        map(tag("integer"), |_| TypeName::Integer),
        map(tag("scalar"), |_| TypeName::Scalar),
        map(tag("logical"), |_| TypeName::Logical),
        map(tag("string"), |_| TypeName::String),
        // Only compiled in when the `complex` feature is enabled.
        #[cfg(feature = "complex")]
        map(tag("complex"), |_| TypeName::Complex),
        map(tag("?"), |_| TypeName::Any),
    )))(i)
}
160
161// <tensor-type-spec> ::= "tensor" "<" [<type-name>] ">"
162fn tensor_type_spec(i: &str) -> IResult<&str, TypeSpec> {
163    map(delimited(pair(stag("tensor"), stag("<")), type_name, stag(">")), TypeSpec::Tensor)(i)
164}
165
166// <tuple-type-spec> ::= "(" <type-spec> ("," <type-spec>)+ ")"
167fn tuple_type_spec(i: &str) -> IResult<&str, TypeSpec> {
168    map(delimited(stag("("), separated_list0(stag(","), type_spec), stag(")")), TypeSpec::Tuple)(i)
169}
170
171// GRAPH
172
173// <graph-definition> ::= <graph-declaration> <body>
174// <graph-declaration> ::= "graph" <identifier> "(" <identifier-list> ")" "->" "(" <identifier-list> ")"
175// <identifier-list> ::= <identifier> ("," <identifier>)*
176fn graph_def(i: &str) -> IResult<&str, GraphDef> {
177    let (i, _) = stag("graph")(i)?;
178    let (i, id) = identifier(i)?;
179    let (i, _) = stag("(")(i)?;
180    let (i, parameters) = separated_list0(stag(","), identifier)(i)?;
181    let (i, _) = stag(")")(i)?;
182    let (i, _) = stag("->")(i)?;
183    let (i, _) = stag("(")(i)?;
184    let (i, results) = separated_list0(stag(","), identifier)(i)?;
185    let (i, _) = stag(")")(i)?;
186    let (i, body) = spaced(body)(i)?;
187    Ok((i, GraphDef { id, parameters, results, body }))
188}
189
190// BODY
191
192// <body> ::= "{" <assignment>+ "}"
193fn body(i: &str) -> IResult<&str, Vec<Assignment>> {
194    delimited(stag("{"), many0(assignment), stag("}"))(i)
195}
196
197// <assignment> ::= <lvalue-expr> "=" <rvalue-expr> ";"
198fn assignment(i: &str) -> IResult<&str, Assignment> {
199    spaced(terminated(
200        map(separated_pair(lvalue, stag("="), rvalue), |(left, right)| Assignment { left, right }),
201        stag(";"),
202    ))(i)
203}
204
205// <lvalue-expr> ::= <identifier> | <array-lvalue-expr> | <tuple-lvalue-expr>
206// <array-lvalue-expr> ::= "[" [<lvalue-expr> ("," <lvalue-expr>)* ] "]"
207// <tuple-lvalue-expr> ::= "(" <lvalue-expr> ("," <lvalue-expr>)+ ")" | <lvalue-expr> ("," <lvalue-expr>)+
208fn lvalue(i: &str) -> IResult<&str, LValue> {
209    fn inner_lvalue(i: &str) -> IResult<&str, LValue> {
210        alt((
211            map(
212                delimited(stag("["), separated_list0(stag(","), inner_lvalue), stag("]")),
213                LValue::Array,
214            ),
215            map(
216                delimited(stag("("), separated_list0(stag(","), inner_lvalue), stag(")")),
217                LValue::Tuple,
218            ),
219            map(spaced(identifier), LValue::Identifier),
220        ))(i)
221    }
222
223    map(separated_list0(stag(","), inner_lvalue), |mut iv| {
224        if iv.len() == 1 {
225            iv.remove(0)
226        } else {
227            LValue::Tuple(iv)
228        }
229    })(i)
230}
231
232// <invocation> ::= <identifier> ["<" <type-name> ">"] "(" <argument-list> ")"
233fn invocation(i: &str) -> IResult<&str, Invocation> {
234    let (i, id) = spaced(identifier)(i)?;
235    let (i, generic_type_name) = opt(delimited(stag("<"), type_name, stag(">")))(i)?;
236    let (i, _) = stag("(")(i)?;
237    let (i, arguments) = argument_list(i)?;
238    let (i, _) = stag(")")(i)?;
239    Ok((i, Invocation { id, generic_type_name, arguments }))
240}
241
242// <argument-list> ::= <argument> ("," <argument>)*
243fn argument_list(i: &str) -> IResult<&str, Vec<Argument>> {
244    separated_list0(stag(","), argument)(i)
245}
246
247// <argument> ::= <rvalue-expr> | <identifier> "=" <rvalue-expr>
248fn argument(i: &str) -> IResult<&str, Argument> {
249    spaced(map(pair(opt(terminated(identifier, stag("="))), rvalue), |(id, rvalue)| Argument {
250        id,
251        rvalue,
252    }))(i)
253}
254
255//<rvalue-expr> ::= <identifier> | <literal> | <binary-expr> | <unary-expr> | <paren-expr>
256//                  | <array-rvalue-expr> | <tuple-rvalue-expr> | <subscript-expr> | <if-else-expr>
257//                  | <comprehension-expr> | <builtin-expr> | <invocation>
258fn rvalue(i: &str) -> IResult<&str, RValue> {
259    fn atom(i: &str) -> IResult<&str, RValue> {
260        spaced(alt((
261            map(invocation, RValue::Invocation),
262            map(literal, RValue::Literal),
263            map(identifier, RValue::Identifier),
264            map(pair(spaced(recognize(one_of("+-!"))), rvalue), |(op, rv)| {
265                RValue::Unary(op.into(), Box::new(rv))
266            }),
267            map(delimited(tag("("), separated_list0(stag(","), rvalue), tag(")")), |mut rvs| {
268                if rvs.len() == 1 {
269                    rvs.remove(0)
270                } else {
271                    RValue::Tuple(rvs)
272                }
273            }),
274            map(comprehension_expr, |c| RValue::Comprehension(Box::new(c))),
275            map(delimited(tag("["), separated_list0(stag(","), rvalue), tag("]")), |rvs| {
276                RValue::Array(rvs)
277            }),
278        )))(i)
279    }
280    macro_rules! bin {
281        ($name:ident, $operand: ident, $operator: expr) => {
282            fn $name(i: &str) -> IResult<&str, RValue> {
283                let (i, init) = $operand(i)?;
284                fold_many0(
285                    pair($operator, $operand),
286                    move || init.clone(),
287                    |left, (op, right)| {
288                        RValue::Binary(Box::new(left), op.to_string(), Box::new(right))
289                    },
290                )(i)
291            }
292        };
293    }
294
295    // <subscript-expr> ::= <rvalue-expr> "[" (<rvalue-expr> | [<rvalue-expr>] ":" [<rvalue-expr>]) "]"
296    fn sub(i: &str) -> IResult<&str, RValue> {
297        alt((
298            map(
299                pair(
300                    atom,
301                    delimited(
302                        stag("["),
303                        alt((
304                            map(separated_pair(opt(rvalue), stag(":"), opt(rvalue)), |(a, b)| {
305                                Subscript::Range(a, b)
306                            }),
307                            map(rvalue, Subscript::Single),
308                        )),
309                        stag("]"),
310                    ),
311                ),
312                |(rv, range)| RValue::Subscript(Box::new(rv), Box::new(range)),
313            ),
314            atom,
315        ))(i)
316    }
317
318    bin!(exp, sub, tag("^"));
319    bin!(mul, exp, one_of("*/"));
320    bin!(add, mul, one_of("+-"));
321    bin!(comp, add, alt((tag("=="), tag("!="), tag("<"), tag(">"), tag("<="), tag(">="))));
322    bin!(boolean, comp, alt((tag("||"), tag("&&"))));
323    bin!(in_for, boolean, tag("in"));
324
325    // <if-else-expr> ::= <rvalue-expr> "if" <rvalue-expr> "else" <rvalue-expr>
326    fn ite(i: &str) -> IResult<&str, RValue> {
327        let (i, leftmost) = in_for(i)?;
328        let (i, _) = space_and_comments(i)?;
329        if i.starts_with("if") {
330            let (i, _) = stag("if")(i)?;
331            let (i, cond) = in_for(i)?;
332            let (i, _) = stag("else")(i)?;
333            let (i, otherwise) = in_for(i)?;
334            Ok((i, RValue::IfThenElse(Box::new(IfThenElse { cond, then: leftmost, otherwise }))))
335        } else {
336            Ok((i, leftmost))
337        }
338    }
339
340    ite(i)
341}
342
343// <comprehension-expr> ::= "[" "for" <loop-iter-list> ["if" <rvalue-expr>] "yield" <rvalue-expr> "]"
344fn comprehension_expr(i: &str) -> IResult<&str, Comprehension> {
345    delimited(
346        pair(stag("["), stag("for")),
347        map(separated_pair(loop_iters, stag("yield"), rvalue), |(loop_iters, yields)| {
348            Comprehension { loop_iters, filter: None, yields }
349        }),
350        stag("]"),
351    )(i)
352}
353
354// <loop-iter> ::= <identifier> "in" <rvalue-expr>
355// <loop-iter-list> ::= <loop-iter> ("," <loop-iter>)*
356fn loop_iters(i: &str) -> IResult<&str, Vec<(Identifier, RValue)>> {
357    separated_list0(stag(","), separated_pair(identifier, stag("in"), rvalue))(i)
358}
359
360// TERMINALS
361
362// identifier: identifiers must consist of the following ASCII characters: _, [a-z], [A-Z], [0-9].
363// The identifier must not start with a digit.
364pub(super) fn identifier(i: &str) -> IResult<&str, Identifier> {
365    alt((escaped_identifier, direct_identifier))(i)
366}
367
368pub(super) fn direct_identifier(i: &str) -> IResult<&str, Identifier> {
369    map(
370        recognize(pair(alt((alpha1, tag("_"))), many0(alt((alphanumeric1, tag("_")))))),
371        Identifier::from,
372    )(i)
373}
374
375pub(super) fn escaped_identifier(i: &str) -> IResult<&str, Identifier> {
376    map(preceded(tag("i"), string_literal), Identifier)(i)
377}
378
379// <literal> ::= <numeric-literal> | <string-literal> | <logical-literal>
380fn literal(i: &str) -> IResult<&str, Literal> {
381    spaced(alt((
382        map(numeric_literal, Literal::Numeric),
383        map(string_literal, Literal::String),
384        map(logical_literal, Literal::Logical),
385    )))(i)
386}
387
388pub(super) fn numeric_literal(i: &str) -> IResult<&str, String> {
389    fn exp_part(i: &str) -> IResult<&str, &str> {
390        recognize(tuple((one_of("eE"), opt(tag("-")), digit1)))(i)
391    }
392    fn frac_part(i: &str) -> IResult<&str, &str> {
393        recognize(tuple((tag("."), digit0)))(i)
394    }
395    spaced(map(
396        recognize(tuple((opt(tag("-")), alt((digit1, tag("inf"))), opt(frac_part), opt(exp_part)))),
397        |s: &str| s.to_owned(),
398    ))(i)
399}
400
401fn string_literal(i: &str) -> IResult<&str, String> {
402    fn inner(i: &str) -> IResult<&str, String> {
403        map(
404            many0(alt((
405                preceded(tag("\\"), nom::character::complete::anychar),
406                nom::character::complete::none_of("\\\"'"),
407            ))),
408            |v: Vec<char>| v.into_iter().collect(),
409        )(i)
410    }
411    map(alt((delimited(tag("'"), inner, tag("'")), delimited(tag("\""), inner, tag("\"")))), |s| s)(
412        i,
413    )
414}
415
416pub(super) fn logical_literal(i: &str) -> IResult<&str, bool> {
417    spaced(alt((map(tag("true"), |_| true), map(tag("false"), |_| false))))(i)
418}
419
420// SPACES
421
422fn space_and_comments(i: &str) -> IResult<&str, ()> {
423    map(
424        many0(alt((
425            recognize(one_of(" \t\n\r")),
426            recognize(tuple((tag("#"), many0(none_of("\r\n"))))),
427        ))),
428        |_| (),
429    )(i)
430}
431
432fn spaced<'s, O, F>(it: F) -> impl FnMut(&'s str) -> IResult<&'s str, O>
433where
434    F: FnMut(&'s str) -> IResult<&'s str, O>,
435{
436    delimited(space_and_comments, it, space_and_comments)
437}
438
439pub(super) fn stag<'s>(t: &'static str) -> impl FnMut(&'s str) -> IResult<&'s str, &'s str> {
440    spaced(tag(t))
441}
442
#[cfg(test)]
mod test {
    // Unit tests for the individual grammar rules. Most are smoke tests that only require the
    // input to be parsed completely without error.
    use super::*;
    use TypeName::*;
    use TypeSpec::*;

    /// Runs `parser` over the complete `input` and returns the parsed value.
    ///
    /// Panics if the parser fails or leaves any input unconsumed.
    fn parse_all<'s, P, O, E>(parser: P, input: &'s str) -> O
    where
        O: std::fmt::Debug,
        P: Fn(&'s str) -> IResult<&'s str, O, E>,
        E: nom::error::ParseError<&'s str> + std::fmt::Debug,
    {
        let (_rest, value) = all_consuming(parser)(input).unwrap();
        value
    }

    /// Expected fragment parameter with no default literal and no documentation.
    fn param(name: impl Into<std::string::String>, spec: TypeSpec) -> Parameter {
        Parameter { id: Identifier(name.into()), spec, lit: None, doc: None }
    }

    /// Expected fragment result.
    fn output(name: impl Into<std::string::String>, spec: TypeSpec) -> Result_ {
        Result_ { id: Identifier(name.into()), spec }
    }

    /// Expected `<inner>[]` type spec.
    fn array_of(inner: TypeSpec) -> TypeSpec {
        Array(Box::new(inner))
    }

    #[test]
    fn test_type_spec() {
        let scalar_array = || array_of(Single(Scalar));
        let mixed_tuple = || Tuple(vec![Single(Scalar), scalar_array(), Tensor(Scalar)]);
        assert_eq!(parse_all(type_spec, "scalar"), Single(Scalar));
        assert_eq!(parse_all(type_spec, "scalar[]"), scalar_array());
        assert_eq!(parse_all(type_spec, "tensor<scalar>[]"), array_of(Tensor(TypeName::Scalar)));
        assert_eq!(parse_all(type_spec, "(scalar,scalar[],tensor<scalar>)"), mixed_tuple());
        assert_eq!(parse_all(type_spec, "tensor<?>[]"), array_of(Tensor(TypeName::Any)));
        assert_eq!(parse_all(type_spec, "scalar[ ]"), scalar_array());
        assert_eq!(
            parse_all(type_spec, " ( scalar , scalar [ ] , tensor < scalar > ) "),
            mixed_tuple()
        );
        #[cfg(feature = "complex")]
        assert_eq!(
            parse_all(type_spec, "tensor<complex>[]"),
            array_of(Tensor(TypeName::Complex))
        );
    }

    #[test]
    fn test_fragment_decl_fizz() {
        let expected = FragmentDecl {
            id: "fizz".into(),
            generic_decl: Some(Some(Scalar)),
            parameters: vec![param("shape", array_of(Single(Integer)))],
            results: vec![output("output", Tensor(Any))],
        };
        let parsed = parse_all(
            fragment_decl,
            "fragment fizz<? = scalar>( shape: integer[] ) -> ( output: tensor<?> )",
        );
        assert_eq!(parsed, expected);
    }

    #[test]
    fn test_fragment_decl_logarithmic_quantize() {
        let expected = FragmentDecl {
            id: "logarithmic_quantize".into(),
            generic_decl: None,
            parameters: vec![
                param("x", Tensor(Scalar)),
                param("max", Tensor(Scalar)),
                param("bits", Single(Integer)),
            ],
            results: vec![output("y", Tensor(Scalar))],
        };
        let parsed = parse_all(
            fragment_decl,
            "fragment logarithmic_quantize(x: tensor<scalar>, max: tensor<scalar>, bits: integer ) -> ( y: tensor<scalar> )",
        );
        assert_eq!(parsed, expected);
    }

    #[test]
    fn test_fragment_decl_external() {
        parse_all(
            fragment_decl,
            "fragment external<? = scalar>( shape: integer[] ) -> ( output: tensor<?> )",
        );
    }

    #[test]
    fn test_fragment_reshape() {
        parse_all(fragments, "fragment reshape<?>( input: tensor<?>, shape: integer[], axis_start: integer = 0, axis_count: integer = -1 ) -> ( output: tensor<?> );");
    }

    #[test]
    fn test_fragment_conv() {
        parse_all(
            fragments,
            r#"
            fragment conv(
                input: tensor<scalar>,
                filter: tensor<scalar>,
                bias: tensor<scalar> = 0.0,
                border: string = 'constant',
                padding: (integer,integer)[] = [],
                stride: integer[] = [],
                dilation: integer[] = [],
                groups: integer = 1 )
            -> ( output: tensor<scalar> );
            "#,
        );
    }

    #[test]
    fn test_fragment_local_response_normalization() {
        parse_all(
            fragments,
            r#"
            fragment local_response_normalization(
                input: tensor<scalar>,
                size: integer[],
                alpha: scalar = 1.0,
                beta: scalar = 0.5,
                bias: scalar = 1.0 )
            -> ( output: tensor<scalar> )
            {
                sigma = bias + alpha * box(sqr(input), size = size, normalize = true);
                output = input / (sigma ^ beta);
            }
            "#,
        );
    }

    #[test]
    fn test_batch_normalization() {
        parse_all(
            fragments,
            r#"
            fragment batch_normalization( input: tensor<scalar>, mean: tensor<scalar>, variance: tensor<scalar>, offset: tensor<scalar>, scale: tensor<scalar>, epsilon: scalar )
            -> ( output: tensor<scalar> )
            {
                output = offset + scale * (input - mean) / sqrt(variance + epsilon);
            }
            "#,
        );
    }

    #[test]
    fn test_avg_roi_align() {
        parse_all(
            fragments,
            r#"
                fragment avg_roi_align(
                    input: tensor<scalar>,
                    rois: tensor<scalar>,
                    batch_index: tensor<integer>,
                    output_size: integer[],
                    sampling_rate: integer[],
                    resize_method: string = 'symmetric' )
                -> ( output: tensor<scalar> )
                {
                    size = [for i in range_of(output_size) yield output_size[i] * sampling_rate[i]];
                    resized = roi_resample(input, rois, batch_index, output_size = size,
                                         method = resize_method);
                    output = avg_pool(resized, size = sampling_rate, stride = sampling_rate);
                }
            "#,
        );
    }

    #[test]
    fn test_min_max_linear_quantize() {
        parse_all(
            fragments,
            r#"
                fragment min_max_linear_quantize(
                    x: tensor<scalar>,
                    min: tensor<scalar>,
                    max: tensor<scalar>,
                    bits: integer,
                    signed: logical,
                    symmetric: logical )
                -> ( y: tensor<scalar> )
                {
                    r = scalar(2 ^ bits - 1 - integer(signed && symmetric));
                    z = clamp(x, min, max);
                    p = scalar(2 ^ (bits - 1) - integer(symmetric) if signed else 0);
                    q = round((z - min) / (max - min) * r) - p;
                    y = (q + p) / r * (max - min) + min;
}
            "#,
        );
    }

    #[test]
    fn test_numeric() {
        parse_all(numeric_literal, "12.0");
    }

    #[test]
    fn test_string() {
        let cases = [
            (r#""""#, ""),
            (r#""foo""#, "foo"),
            (r#"''"#, ""),
            (r#"'foo'"#, "foo"),
            (r"'f\oo'", "foo"),
            (r"'f\'oo'", "f'oo"),
            (r#"'f\"oo'"#, "f\"oo"),
        ];
        for (src, expected) in cases {
            assert_eq!(parse_all(string_literal, src), expected);
        }
    }

    #[test]
    fn test_identifier() {
        parse_all(identifier, "foo");
        for bad in ["1", "1foo"] {
            assert!(identifier(bad).is_err());
        }
    }

    #[test]
    fn test_spacing() {
        for src in ["", "\n", "#comment\n", "#boum"] {
            parse_all(space_and_comments, src);
        }
    }

    #[test]
    fn test_spaced() {
        let ident = |s: &str| Identifier(s.to_string());
        let mut words = many1(spaced(identifier));
        assert!(spaced(identifier)("foo").is_ok());
        assert!(spaced(identifier)(" foo ").is_ok());
        assert!(words(" foo bar ").is_ok());
        assert_eq!(words(" foo bar\n").unwrap().1, &[ident("foo"), ident("bar")]);
        assert_eq!(words(" foo # bar\n").unwrap().1, &[ident("foo")]);
        assert_eq!(words(" foo # bar\nbaz").unwrap().1, &[ident("foo"), ident("baz")]);
    }

    #[test]
    fn test_document() {
        assert!(document("version 1.0; graph foo() -> () {}").is_ok());
    }

    #[test]
    fn test_version() {
        parse_all(version, "version 1.0;");
    }

    #[test]
    fn test_body() {
        for src in ["{}", "{foo=bar;}"] {
            parse_all(body, src);
        }
    }

    #[test]
    fn test_lvalue() {
        for src in ["foo", "foo,bar", "foo , bar", "(foo,bar)"] {
            parse_all(lvalue, src);
        }
    }

    #[test]
    fn test_graph_def() {
        parse_all(graph_def, "graph foo() -> () {}");
    }

    #[test]
    fn test_assignment() {
        for src in [
            "input = external(12);",
            "input = external(shape = [1, 3, 224, 224]);",
            "sigma = bias + alpha * box(sqr(input), size = size, normalize = true);",
            "output = offset + scale * (input - mean) / sqrt(variance + epsilon);",
            "size = [for i in range_of(output_size) yield output_size[i] * sampling_rate[i]];",
            "r = scalar(2 ^ bits - 1 - integer(signed && symmetric));",
            "output, index = max_pool_with_index(input, size = size, border = border, padding = padding, stride = stride, dilation = dilation);",
        ] {
            parse_all(assignment, src);
        }
    }

    #[test]
    fn test_invocation() {
        for src in ["external(12)", "sqrt(var + eps)"] {
            parse_all(invocation, src);
        }
    }

    #[test]
    fn test_arguments() {
        for src in ["2", "12", "shape = [1, 3, 224, 224]"] {
            parse_all(argument, src);
        }
    }

    #[test]
    fn test_rvalue() {
        for src in [
            "12",
            "(0, 0)",
            "x ^ 2.0",
            "1+2",
            "1+sqrt(var)",
            "1+sqrt(var+eps)",
            "1 + sqrt(var + eps)",
            "[for i in range_of(output_size) yield output_size[i] * sampling_rate[i]]",
            "scalar(2 ^ (bits - 1) - integer(symmetric) if signed else 0)",
        ] {
            parse_all(rvalue, src);
        }
    }

    #[test]
    fn test_comprehenion() {
        parse_all(
            comprehension_expr,
            "[for i in range_of(output_size) yield output_size * sampling_rate]",
        );
    }

    #[test]
    fn test_freeze() {
        parse_all(
            document,
            r#"
version 1.0;

graph y( x, s, bias ) -> ( y ) {
  x = external<scalar>(shape = [1, 2, 1, 3]);
  s = external<scalar>(shape = [2]);
  bias = external<scalar>(shape = [2]);
  y = add(
        mul(
            mul(
                sub(
                    x,
                    mul(
                        0.33333334,
                        sum_reduce(
                            x,
                            axes = [0, 2, 3]
                        )
                    )
                ),
                rsqrt(
                    add(
                        0.00001,
                        mul(
                            0.33333334,
                            sum_reduce(
                                square(
                                    sub(
                                        x,
                                        mul(
                                            0.33333334,
                                            sum_reduce(
                                                x,
                                                axes = [0, 2, 3]
                                            )
                                        )
                                    )
                                ),
                                axes = [0, 2, 3]
                            )
                        )
                    )
                )
            ),
            unsqueeze(
                unsqueeze(
                    unsqueeze(
                        s,
                        axes = [0]
                    ),
                axes = [2]
                ),
            axes = [2]
            )
        ),
        unsqueeze(
            unsqueeze(
                unsqueeze(
                    bias,
                    axes = [0]
                ),
                axes = [2]
            ),
            axes = [2]
        )
    );
}

"#,
        );
    }

    #[test]
    fn test_fragments() {
        parse_all(
            fragments,
            r#"
            fragment add( x: tensor<scalar>, y: tensor<scalar> ) -> ( z: tensor<scalar> );
            fragment sub( x: tensor<scalar>, y: tensor<scalar> ) -> ( z: tensor<scalar> );
            "#,
        );
    }
}