1use tract_core::internal::*;
2
3use nom::branch::alt;
4use nom::combinator::map;
5use nom::IResult;
6use nom::{bytes::complete::*, character::complete::*, combinator::*, multi::*, sequence::*};
7
8use crate::ast::*;
9
10pub(super) fn translate_error<E: std::fmt::Debug>(e: E) -> TractError {
11 format_err!("Fail to parse NNEF document: {:?}", e)
12}
13
14#[inline(never)]
15pub fn parse_document(doc: &str) -> TractResult<Document> {
16 all_consuming(document)(doc).map(|pair| pair.1).map_err(translate_error)
17}
18
19#[inline(never)]
20pub fn parse_fragments(doc: &str) -> TractResult<Vec<FragmentDef>> {
21 all_consuming(fragments)(doc).map(|pair| pair.1).map_err(translate_error)
22}
23
24#[inline(never)]
25pub fn parse_fragment_decl(doc: &str) -> TractResult<FragmentDecl> {
26 all_consuming(fragment_decl)(doc).map(|pair| pair.1).map_err(translate_error)
27}
28
29#[inline(never)]
30pub fn parse_parameters(doc: &str) -> TractResult<Vec<Parameter>> {
31 all_consuming(parameter_list)(doc).map(|pair| pair.1).map_err(translate_error)
32}
33
34fn document(i: &str) -> IResult<&str, Document> {
36 map(
37 tuple((version, many0(extension), fragments, graph_def)),
38 |(version, extension, fragments, graph_def)| Document {
39 version,
40 extension,
41 fragments,
42 graph_def,
43 },
44 )(i)
45}
46
47fn fragments(i: &str) -> IResult<&str, Vec<FragmentDef>> {
48 many0(fragment_def)(i)
49}
50
51fn version(i: &str) -> IResult<&str, NumericLiteral> {
54 delimited(stag("version"), numeric_literal, stag(";"))(i)
55}
56
57fn extension(i: &str) -> IResult<&str, (Identifier, String)> {
60 delimited(
61 stag("extension"),
62 pair(spaced(identifier), map(take_until(";"), |s: &str| s.to_string())),
63 stag(";"),
64 )(i)
65}
66
67fn fragment_def(i: &str) -> IResult<&str, FragmentDef> {
71 spaced(map(
72 pair(fragment_decl, alt((map(body, Some), map(stag(";"), |_| None)))),
73 |(decl, body)| FragmentDef { decl, body },
74 ))(i)
75}
76
77fn fragment_decl(i: &str) -> IResult<&str, FragmentDecl> {
79 let (i, _) = stag("fragment")(i)?;
80 let (i, id) = identifier(i)?;
81 let (i, generic_decl) = opt(generic_decl)(i)?;
82 let (i, _) = stag("(")(i)?;
83 let (i, parameters) = parameter_list(i)?;
84 let (i, _) = stag(")")(i)?;
85 let (i, _) = stag("->")(i)?;
86 let (i, _) = stag("(")(i)?;
87 let (i, results) = result_list(i)?;
88 let (i, _) = stag(")")(i)?;
89 Ok((i, FragmentDecl { id, parameters, results, generic_decl }))
90}
91
92fn generic_decl(i: &str) -> IResult<&str, Option<TypeName>> {
94 let (i, _) = stag("<")(i)?;
95 let (i, _) = stag("?")(i)?;
96 let (i, name) = opt(preceded(stag("="), type_name))(i)?;
97 let (i, _) = stag(">")(i)?;
98 Ok((i, name))
99}
100
101fn parameter_list(i: &str) -> IResult<&str, Vec<Parameter>> {
103 separated_list0(stag(","), parameter)(i)
104}
105
106fn result_list(i: &str) -> IResult<&str, Vec<Result_>> {
108 separated_list0(stag(","), result)(i)
109}
110
111fn parameter(i: &str) -> IResult<&str, Parameter> {
113 map(
114 pair(
115 separated_pair(identifier, stag(":"), type_spec),
116 opt(preceded(stag("="), literal_expr)),
117 ),
118 |((id, spec), lit)| Parameter { id, spec, lit, doc: None },
119 )(i)
120}
121
122fn result(i: &str) -> IResult<&str, Result_> {
124 map(separated_pair(identifier, stag(":"), type_spec), |(id, spec)| Result_ { id, spec })(i)
125}
126
127fn literal_expr(i: &str) -> IResult<&str, Literal> {
128 spaced(alt((
129 literal,
130 map(delimited(stag("["), separated_list0(stag(","), literal), stag("]")), Literal::Array),
131 map(delimited(stag("("), separated_list0(stag(","), literal), stag(")")), Literal::Tuple),
132 )))(i)
133}
134
135fn type_spec(i: &str) -> IResult<&str, TypeSpec> {
137 fn non_array_type(i: &str) -> IResult<&str, TypeSpec> {
138 alt((tuple_type_spec, map(type_name, TypeSpec::Single), tensor_type_spec))(i)
139 }
140 alt((
141 (map(terminated(non_array_type, pair(stag("["), stag("]"))), |t| {
142 TypeSpec::Array(Box::new(t))
143 })),
144 non_array_type,
145 ))(i)
146}
147
148fn type_name(i: &str) -> IResult<&str, TypeName> {
150 spaced(alt((
151 map(tag("integer"), |_| TypeName::Integer),
152 map(tag("scalar"), |_| TypeName::Scalar),
153 map(tag("logical"), |_| TypeName::Logical),
154 map(tag("string"), |_| TypeName::String),
155 #[cfg(feature = "complex")]
156 map(tag("complex"), |_| TypeName::Complex),
157 map(tag("?"), |_| TypeName::Any),
158 )))(i)
159}
160
161fn tensor_type_spec(i: &str) -> IResult<&str, TypeSpec> {
163 map(delimited(pair(stag("tensor"), stag("<")), type_name, stag(">")), TypeSpec::Tensor)(i)
164}
165
166fn tuple_type_spec(i: &str) -> IResult<&str, TypeSpec> {
168 map(delimited(stag("("), separated_list0(stag(","), type_spec), stag(")")), TypeSpec::Tuple)(i)
169}
170
171fn graph_def(i: &str) -> IResult<&str, GraphDef> {
177 let (i, _) = stag("graph")(i)?;
178 let (i, id) = identifier(i)?;
179 let (i, _) = stag("(")(i)?;
180 let (i, parameters) = separated_list0(stag(","), identifier)(i)?;
181 let (i, _) = stag(")")(i)?;
182 let (i, _) = stag("->")(i)?;
183 let (i, _) = stag("(")(i)?;
184 let (i, results) = separated_list0(stag(","), identifier)(i)?;
185 let (i, _) = stag(")")(i)?;
186 let (i, body) = spaced(body)(i)?;
187 Ok((i, GraphDef { id, parameters, results, body }))
188}
189
190fn body(i: &str) -> IResult<&str, Vec<Assignment>> {
194 delimited(stag("{"), many0(assignment), stag("}"))(i)
195}
196
197fn assignment(i: &str) -> IResult<&str, Assignment> {
199 spaced(terminated(
200 map(separated_pair(lvalue, stag("="), rvalue), |(left, right)| Assignment { left, right }),
201 stag(";"),
202 ))(i)
203}
204
205fn lvalue(i: &str) -> IResult<&str, LValue> {
209 fn inner_lvalue(i: &str) -> IResult<&str, LValue> {
210 alt((
211 map(
212 delimited(stag("["), separated_list0(stag(","), inner_lvalue), stag("]")),
213 LValue::Array,
214 ),
215 map(
216 delimited(stag("("), separated_list0(stag(","), inner_lvalue), stag(")")),
217 LValue::Tuple,
218 ),
219 map(spaced(identifier), LValue::Identifier),
220 ))(i)
221 }
222
223 map(separated_list0(stag(","), inner_lvalue), |mut iv| {
224 if iv.len() == 1 {
225 iv.remove(0)
226 } else {
227 LValue::Tuple(iv)
228 }
229 })(i)
230}
231
232fn invocation(i: &str) -> IResult<&str, Invocation> {
234 let (i, id) = spaced(identifier)(i)?;
235 let (i, generic_type_name) = opt(delimited(stag("<"), type_name, stag(">")))(i)?;
236 let (i, _) = stag("(")(i)?;
237 let (i, arguments) = argument_list(i)?;
238 let (i, _) = stag(")")(i)?;
239 Ok((i, Invocation { id, generic_type_name, arguments }))
240}
241
242fn argument_list(i: &str) -> IResult<&str, Vec<Argument>> {
244 separated_list0(stag(","), argument)(i)
245}
246
247fn argument(i: &str) -> IResult<&str, Argument> {
249 spaced(map(pair(opt(terminated(identifier, stag("="))), rvalue), |(id, rvalue)| Argument {
250 id,
251 rvalue,
252 }))(i)
253}
254
255fn rvalue(i: &str) -> IResult<&str, RValue> {
259 fn atom(i: &str) -> IResult<&str, RValue> {
260 spaced(alt((
261 map(invocation, RValue::Invocation),
262 map(literal, RValue::Literal),
263 map(identifier, RValue::Identifier),
264 map(pair(spaced(recognize(one_of("+-!"))), rvalue), |(op, rv)| {
265 RValue::Unary(op.into(), Box::new(rv))
266 }),
267 map(delimited(tag("("), separated_list0(stag(","), rvalue), tag(")")), |mut rvs| {
268 if rvs.len() == 1 {
269 rvs.remove(0)
270 } else {
271 RValue::Tuple(rvs)
272 }
273 }),
274 map(comprehension_expr, |c| RValue::Comprehension(Box::new(c))),
275 map(delimited(tag("["), separated_list0(stag(","), rvalue), tag("]")), |rvs| {
276 RValue::Array(rvs)
277 }),
278 )))(i)
279 }
280 macro_rules! bin {
281 ($name:ident, $operand: ident, $operator: expr) => {
282 fn $name(i: &str) -> IResult<&str, RValue> {
283 let (i, init) = $operand(i)?;
284 fold_many0(
285 pair($operator, $operand),
286 move || init.clone(),
287 |left, (op, right)| {
288 RValue::Binary(Box::new(left), op.to_string(), Box::new(right))
289 },
290 )(i)
291 }
292 };
293 }
294
295 fn sub(i: &str) -> IResult<&str, RValue> {
297 alt((
298 map(
299 pair(
300 atom,
301 delimited(
302 stag("["),
303 alt((
304 map(separated_pair(opt(rvalue), stag(":"), opt(rvalue)), |(a, b)| {
305 Subscript::Range(a, b)
306 }),
307 map(rvalue, Subscript::Single),
308 )),
309 stag("]"),
310 ),
311 ),
312 |(rv, range)| RValue::Subscript(Box::new(rv), Box::new(range)),
313 ),
314 atom,
315 ))(i)
316 }
317
318 bin!(exp, sub, tag("^"));
319 bin!(mul, exp, one_of("*/"));
320 bin!(add, mul, one_of("+-"));
321 bin!(comp, add, alt((tag("=="), tag("!="), tag("<"), tag(">"), tag("<="), tag(">="))));
322 bin!(boolean, comp, alt((tag("||"), tag("&&"))));
323 bin!(in_for, boolean, tag("in"));
324
325 fn ite(i: &str) -> IResult<&str, RValue> {
327 let (i, leftmost) = in_for(i)?;
328 let (i, _) = space_and_comments(i)?;
329 if i.starts_with("if") {
330 let (i, _) = stag("if")(i)?;
331 let (i, cond) = in_for(i)?;
332 let (i, _) = stag("else")(i)?;
333 let (i, otherwise) = in_for(i)?;
334 Ok((i, RValue::IfThenElse(Box::new(IfThenElse { cond, then: leftmost, otherwise }))))
335 } else {
336 Ok((i, leftmost))
337 }
338 }
339
340 ite(i)
341}
342
343fn comprehension_expr(i: &str) -> IResult<&str, Comprehension> {
345 delimited(
346 pair(stag("["), stag("for")),
347 map(separated_pair(loop_iters, stag("yield"), rvalue), |(loop_iters, yields)| {
348 Comprehension { loop_iters, filter: None, yields }
349 }),
350 stag("]"),
351 )(i)
352}
353
354fn loop_iters(i: &str) -> IResult<&str, Vec<(Identifier, RValue)>> {
357 separated_list0(stag(","), separated_pair(identifier, stag("in"), rvalue))(i)
358}
359
360pub(super) fn identifier(i: &str) -> IResult<&str, Identifier> {
365 alt((escaped_identifier, direct_identifier))(i)
366}
367
368pub(super) fn direct_identifier(i: &str) -> IResult<&str, Identifier> {
369 map(
370 recognize(pair(alt((alpha1, tag("_"))), many0(alt((alphanumeric1, tag("_")))))),
371 Identifier::from,
372 )(i)
373}
374
375pub(super) fn escaped_identifier(i: &str) -> IResult<&str, Identifier> {
376 map(preceded(tag("i"), string_literal), Identifier)(i)
377}
378
379fn literal(i: &str) -> IResult<&str, Literal> {
381 spaced(alt((
382 map(numeric_literal, Literal::Numeric),
383 map(string_literal, Literal::String),
384 map(logical_literal, Literal::Logical),
385 )))(i)
386}
387
388pub(super) fn numeric_literal(i: &str) -> IResult<&str, String> {
389 fn exp_part(i: &str) -> IResult<&str, &str> {
390 recognize(tuple((one_of("eE"), opt(tag("-")), digit1)))(i)
391 }
392 fn frac_part(i: &str) -> IResult<&str, &str> {
393 recognize(tuple((tag("."), digit0)))(i)
394 }
395 spaced(map(
396 recognize(tuple((opt(tag("-")), alt((digit1, tag("inf"))), opt(frac_part), opt(exp_part)))),
397 |s: &str| s.to_owned(),
398 ))(i)
399}
400
401fn string_literal(i: &str) -> IResult<&str, String> {
402 fn inner(i: &str) -> IResult<&str, String> {
403 map(
404 many0(alt((
405 preceded(tag("\\"), nom::character::complete::anychar),
406 nom::character::complete::none_of("\\\"'"),
407 ))),
408 |v: Vec<char>| v.into_iter().collect(),
409 )(i)
410 }
411 map(alt((delimited(tag("'"), inner, tag("'")), delimited(tag("\""), inner, tag("\"")))), |s| s)(
412 i,
413 )
414}
415
416pub(super) fn logical_literal(i: &str) -> IResult<&str, bool> {
417 spaced(alt((map(tag("true"), |_| true), map(tag("false"), |_| false))))(i)
418}
419
420fn space_and_comments(i: &str) -> IResult<&str, ()> {
423 map(
424 many0(alt((
425 recognize(one_of(" \t\n\r")),
426 recognize(tuple((tag("#"), many0(none_of("\r\n"))))),
427 ))),
428 |_| (),
429 )(i)
430}
431
432fn spaced<'s, O, F>(it: F) -> impl FnMut(&'s str) -> IResult<&'s str, O>
433where
434 F: FnMut(&'s str) -> IResult<&'s str, O>,
435{
436 delimited(space_and_comments, it, space_and_comments)
437}
438
439pub(super) fn stag<'s>(t: &'static str) -> impl FnMut(&'s str) -> IResult<&'s str, &'s str> {
440 spaced(tag(t))
441}
442
/// Unit tests for the NNEF text parser.
///
/// Each test runs one grammar rule on a literal snippet through `p`, which requires
/// the whole input to be consumed.
#[cfg(test)]
mod test {
    use super::*;
    use TypeName::*;
    use TypeSpec::*;

    /// Runs `parser` on `i` and returns the parsed value.
    ///
    /// Panics unless the parser succeeds and consumes all of `i`.
    fn p<'s, P, O, E>(parser: P, i: &'s str) -> O
    where
        O: std::fmt::Debug,
        P: Fn(&'s str) -> IResult<&'s str, O, E>,
        E: nom::error::ParseError<&'s str> + std::fmt::Debug,
    {
        let res = all_consuming(parser)(i).unwrap();
        res.1
    }

    /// Builds an expected `Parameter` with no default literal and no doc.
    // `std::string::String` is spelled out because `TypeName::*` brings a `String` variant into scope.
    fn param(s: impl Into<std::string::String>, t: TypeSpec) -> Parameter {
        Parameter { id: Identifier(s.into()), spec: t, lit: None, doc: None }
    }

    /// Builds an expected fragment `Result_` entry.
    fn result(s: impl Into<std::string::String>, t: TypeSpec) -> Result_ {
        Result_ { id: Identifier(s.into()), spec: t }
    }

    #[test]
    fn test_type_spec() {
        assert_eq!(p(type_spec, "scalar"), Single(Scalar));
        assert_eq!(p(type_spec, "scalar[]"), Array(Box::new(Single(Scalar))));
        assert_eq!(p(type_spec, "tensor<scalar>[]"), Array(Box::new(Tensor(TypeName::Scalar))));
        assert_eq!(
            p(type_spec, "(scalar,scalar[],tensor<scalar>)"),
            Tuple(vec!(Single(Scalar), Array(Box::new(Single(Scalar))), Tensor(Scalar)))
        );
        assert_eq!(p(type_spec, "tensor<?>[]"), Array(Box::new(Tensor(TypeName::Any))));
        assert_eq!(p(type_spec, "scalar[ ]"), Array(Box::new(Single(Scalar))));
        // Same tuple type as above, with whitespace around every token.
        assert_eq!(
            p(type_spec, " ( scalar , scalar [ ] , tensor < scalar > ) "),
            Tuple(vec!(Single(Scalar), Array(Box::new(Single(Scalar))), Tensor(Scalar)))
        );
        #[cfg(feature = "complex")]
        assert_eq!(p(type_spec, "tensor<complex>[]"), Array(Box::new(Tensor(TypeName::Complex))));
    }

    #[test]
    fn test_fragment_decl_fizz() {
        let parsed = p(
            fragment_decl,
            "fragment fizz<? = scalar>( shape: integer[] ) -> ( output: tensor<?> )",
        );
        assert_eq!(
            parsed,
            FragmentDecl {
                id: "fizz".into(),
                generic_decl: Some(Some(Scalar)),
                parameters: vec!(param("shape", Array(Box::new(Single(Integer)))),),
                results: vec!(result("output", Tensor(Any))),
            }
        );
    }

    #[test]
    fn test_fragment_decl_logarithmic_quantize() {
        let parsed = p(fragment_decl,
            "fragment logarithmic_quantize(x: tensor<scalar>, max: tensor<scalar>, bits: integer ) -> ( y: tensor<scalar> )"
        );
        assert_eq!(
            parsed,
            FragmentDecl {
                id: "logarithmic_quantize".into(),
                generic_decl: None,
                parameters: vec!(
                    param("x", Tensor(Scalar)),
                    param("max", Tensor(Scalar)),
                    param("bits", Single(Integer))
                ),
                results: vec!(result("y", Tensor(Scalar))),
            }
        );
    }

    #[test]
    fn test_fragment_decl_external() {
        p(
            fragment_decl,
            "fragment external<? = scalar>( shape: integer[] ) -> ( output: tensor<?> )",
        );
    }

    #[test]
    fn test_fragment_reshape() {
        p(fragments, "fragment reshape<?>( input: tensor<?>, shape: integer[], axis_start: integer = 0, axis_count: integer = -1 ) -> ( output: tensor<?> );");
    }

    #[test]
    fn test_fragment_conv() {
        p(
            fragments,
            r#"
            fragment conv(
                input: tensor<scalar>,
                filter: tensor<scalar>,
                bias: tensor<scalar> = 0.0,
                border: string = 'constant',
                padding: (integer,integer)[] = [],
                stride: integer[] = [],
                dilation: integer[] = [],
                groups: integer = 1 )
            -> ( output: tensor<scalar> );
            "#,
        );
    }

    #[test]
    fn test_fragment_local_response_normalization() {
        p(
            fragments,
            r#"
            fragment local_response_normalization(
                input: tensor<scalar>,
                size: integer[],
                alpha: scalar = 1.0,
                beta: scalar = 0.5,
                bias: scalar = 1.0 )
            -> ( output: tensor<scalar> )
            {
                sigma = bias + alpha * box(sqr(input), size = size, normalize = true);
                output = input / (sigma ^ beta);
            }
            "#,
        );
    }

    #[test]
    fn test_batch_normalization() {
        p(
            fragments,
            r#"
            fragment batch_normalization( input: tensor<scalar>, mean: tensor<scalar>, variance: tensor<scalar>, offset: tensor<scalar>, scale: tensor<scalar>, epsilon: scalar )
            -> ( output: tensor<scalar> )
            {
                output = offset + scale * (input - mean) / sqrt(variance + epsilon);
            }
            "#,
        );
    }

    #[test]
    fn test_avg_roi_align() {
        p(
            fragments,
            r#"
            fragment avg_roi_align(
                input: tensor<scalar>,
                rois: tensor<scalar>,
                batch_index: tensor<integer>,
                output_size: integer[],
                sampling_rate: integer[],
                resize_method: string = 'symmetric' )
            -> ( output: tensor<scalar> )
            {
                size = [for i in range_of(output_size) yield output_size[i] * sampling_rate[i]];
                resized = roi_resample(input, rois, batch_index, output_size = size,
                                     method = resize_method);
                output = avg_pool(resized, size = sampling_rate, stride = sampling_rate);
            }
            "#,
        );
    }

    #[test]
    fn test_min_max_linear_quantize() {
        p(
            fragments,
            r#"
            fragment min_max_linear_quantize(
                x: tensor<scalar>,
                min: tensor<scalar>,
                max: tensor<scalar>,
                bits: integer,
                signed: logical,
                symmetric: logical )
            -> ( y: tensor<scalar> )
            {
                r = scalar(2 ^ bits - 1 - integer(signed && symmetric));
                z = clamp(x, min, max);
                p = scalar(2 ^ (bits - 1) - integer(symmetric) if signed else 0);
                q = round((z - min) / (max - min) * r) - p;
                y = (q + p) / r * (max - min) + min;
}
            "#,
        );
    }

    #[test]
    fn test_numeric() {
        p(numeric_literal, "12.0");
    }

    /// Both quote styles, empty strings, and backslash escapes (`\x` yields `x`).
    #[test]
    fn test_string() {
        assert_eq!(p(string_literal, r#""""#), "");
        assert_eq!(p(string_literal, r#""foo""#), "foo");
        assert_eq!(p(string_literal, r#"''"#), "");
        assert_eq!(p(string_literal, r#"'foo'"#), "foo");

        assert_eq!(p(string_literal, r"'f\oo'"), "foo");
        assert_eq!(p(string_literal, r"'f\'oo'"), "f'oo");
        assert_eq!(p(string_literal, r#"'f\"oo'"#), "f\"oo");
    }

    #[test]
    fn test_identifier() {
        p(identifier, "foo");
        assert!(identifier("1").is_err());
        assert!(identifier("1foo").is_err());
    }

    #[test]
    fn test_spacing() {
        p(space_and_comments, "");
        p(space_and_comments, "\n");
        p(space_and_comments, "#comment\n");
        p(space_and_comments, "#boum");
    }

    /// `spaced` skips leading/trailing whitespace and `#` comments around each item.
    #[test]
    fn test_spaced() {
        assert!(spaced(identifier)("foo").is_ok());
        assert!(spaced(identifier)(" foo ").is_ok());
        assert!(many1(spaced(identifier))(" foo bar ").is_ok());
        assert_eq!(
            many1(spaced(identifier))(" foo bar\n").unwrap().1,
            &[Identifier("foo".to_string()), Identifier("bar".to_string())]
        );
        assert_eq!(
            many1(spaced(identifier))(" foo # bar\n").unwrap().1,
            &[Identifier("foo".to_string())]
        );
        assert_eq!(
            many1(spaced(identifier))(" foo # bar\nbaz").unwrap().1,
            &[Identifier("foo".to_string()), Identifier("baz".to_string())]
        );
    }

    #[test]
    fn test_document() {
        assert!(document("version 1.0; graph foo() -> () {}").is_ok());
    }

    #[test]
    fn test_version() {
        p(version, "version 1.0;");
    }

    #[test]
    fn test_body() {
        p(body, "{}");
        p(body, "{foo=bar;}");
    }

    #[test]
    fn test_lvalue() {
        p(lvalue, "foo");
        p(lvalue, "foo,bar");
        p(lvalue, "foo , bar");
        p(lvalue, "(foo,bar)");
    }

    #[test]
    fn test_graph_def() {
        p(graph_def, "graph foo() -> () {}");
    }

    #[test]
    fn test_assignment() {
        p(assignment, "input = external(12);");
        p(assignment, "input = external(shape = [1, 3, 224, 224]);");
        p(assignment, "sigma = bias + alpha * box(sqr(input), size = size, normalize = true);");
        p(assignment, "output = offset + scale * (input - mean) / sqrt(variance + epsilon);");
        p(
            assignment,
            "size = [for i in range_of(output_size) yield output_size[i] * sampling_rate[i]];",
        );
        p(assignment, "r = scalar(2 ^ bits - 1 - integer(signed && symmetric));");
        p(assignment, "output, index = max_pool_with_index(input, size = size, border = border, padding = padding, stride = stride, dilation = dilation);");
    }

    #[test]
    fn test_invocation() {
        p(invocation, "external(12)");
        p(invocation, "sqrt(var + eps)");
    }

    #[test]
    fn test_arguments() {
        p(argument, "2");
        p(argument, "12");
        p(argument, "shape = [1, 3, 224, 224]");
    }

    #[test]
    fn test_rvalue() {
        p(rvalue, "12");
        p(rvalue, "(0, 0)");
        p(rvalue, "x ^ 2.0");
        p(rvalue, "1+2");
        p(rvalue, "1+sqrt(var)");
        p(rvalue, "1+sqrt(var+eps)");
        p(rvalue, "1 + sqrt(var + eps)");
        p(rvalue, "[for i in range_of(output_size) yield output_size[i] * sampling_rate[i]]");
        p(rvalue, "scalar(2 ^ (bits - 1) - integer(symmetric) if signed else 0)");
    }

    #[test]
    fn test_comprehenion() {
        p(comprehension_expr, "[for i in range_of(output_size) yield output_size * sampling_rate]");
    }

    /// A full document with a deeply nested expression (many levels of invocations).
    #[test]
    fn test_freeze() {
        p(
            document,
            r#"
version 1.0;

graph y( x, s, bias ) -> ( y ) {
    x = external<scalar>(shape = [1, 2, 1, 3]);
    s = external<scalar>(shape = [2]);
    bias = external<scalar>(shape = [2]);
    y = add(
        mul(
            mul(
                sub(
                    x,
                    mul(
                        0.33333334,
                        sum_reduce(
                            x,
                            axes = [0, 2, 3]
                        )
                    )
                ),
                rsqrt(
                    add(
                        0.00001,
                        mul(
                            0.33333334,
                            sum_reduce(
                                square(
                                    sub(
                                        x,
                                        mul(
                                            0.33333334,
                                            sum_reduce(
                                                x,
                                                axes = [0, 2, 3]
                                            )
                                        )
                                    )
                                ),
                                axes = [0, 2, 3]
                            )
                        )
                    )
                )
            ),
            unsqueeze(
                unsqueeze(
                    unsqueeze(
                        s,
                        axes = [0]
                    ),
                    axes = [2]
                ),
                axes = [2]
            )
        ),
        unsqueeze(
            unsqueeze(
                unsqueeze(
                    bias,
                    axes = [0]
                ),
                axes = [2]
            ),
            axes = [2]
        )
    );
}

"#,
        );
    }

    #[test]
    fn test_fragments() {
        p(
            fragments,
            r#"
            fragment add( x: tensor<scalar>, y: tensor<scalar> ) -> ( z: tensor<scalar> );
            fragment sub( x: tensor<scalar>, y: tensor<scalar> ) -> ( z: tensor<scalar> );
            "#,
        );
    }
}