1use nom_language::error::{convert_error, VerboseError};
2use tract_core::internal::*;
3
4use nom::branch::alt;
5use nom::combinator::map;
6use nom::{bytes::complete::*, character::complete::*, combinator::*, multi::*, sequence::*};
7use nom::{Finish, IResult, Parser};
8
9use crate::ast::*;
10
11type R<'i, O> = IResult<&'i str, O, VerboseError<&'i str>>;
12
13pub(super) fn translate_error(e: nom::Err<VerboseError<&str>>) -> TractError {
14 format_err!("{}", e)
15}
16
17#[inline(never)]
18pub fn unwrap_parse<'s, P, O>(input: &'s str, parser: P) -> TractResult<O>
19where
20 P: Parser<&'s str, Output = O, Error = VerboseError<&'s str>>,
21{
22 all_consuming(parser)
23 .parse(input)
24 .finish()
25 .map(|(_, p)| p)
26 .map_err(|e| anyhow!(convert_error(input, e)))
27}
28
29pub fn parse_document(doc: &str) -> TractResult<Document> {
30 unwrap_parse(doc, document)
31}
32
33#[inline(never)]
34pub fn parse_fragments(doc: &str) -> TractResult<Vec<FragmentDef>> {
35 unwrap_parse(doc, fragments)
36}
37
38#[inline(never)]
39pub fn parse_fragment_decl(doc: &str) -> TractResult<FragmentDecl> {
40 unwrap_parse(doc, fragment_decl)
41}
42
43#[inline(never)]
44pub fn parse_parameters(doc: &str) -> TractResult<Vec<Parameter>> {
45 unwrap_parse(doc, parameter_list)
46}
47
48fn document(i: &str) -> R<'_, Document> {
50 map(
51 (version, many0(extension), fragments, graph_def),
52 |(version, extension, fragments, graph_def)| Document {
53 version,
54 extension,
55 fragments,
56 graph_def,
57 },
58 )
59 .parse(i)
60}
61
62fn fragments(i: &str) -> R<'_, Vec<FragmentDef>> {
63 many0(fragment_def).parse(i)
64}
65
66fn version(i: &str) -> R<'_, NumericLiteral> {
69 preceded(stag("version"), cut(terminated(numeric_literal, stag(";")))).parse(i)
70}
71
72fn extension(i: &str) -> R<'_, (Identifier, String)> {
75 delimited(
76 stag("extension"),
77 pair(spaced(identifier), map(take_until(";"), |s: &str| s.to_string())),
78 stag(";"),
79 )
80 .parse(i)
81}
82
83fn fragment_def(i: &str) -> R<'_, FragmentDef> {
87 spaced(map(
88 pair(fragment_decl, alt((map(body, Some), map(stag(";"), |_| None)))),
89 |(decl, body)| FragmentDef { decl, body },
90 ))
91 .parse(i)
92}
93
94fn fragment_decl(i: &str) -> R<'_, FragmentDecl> {
96 preceded(stag("fragment"), cut(commited_fragment_decl)).parse(i)
97}
98
99fn commited_fragment_decl(i: &str) -> R<'_, FragmentDecl> {
100 let (i, id) = identifier(i)?;
101 let (i, generic_decl) = opt(generic_decl).parse(i)?;
102 let (i, _) = stag("(").parse(i)?;
103 let (i, parameters) = cut(parameter_list).parse(i)?;
104 let (i, _) = stag(")").parse(i)?;
105 let (i, _) = stag("->").parse(i)?;
106 let (i, _) = stag("(").parse(i)?;
107 let (i, results) = cut(result_list).parse(i)?;
108 let (i, _) = stag(")").parse(i)?;
109 Ok((i, FragmentDecl { id, parameters, results, generic_decl }))
110}
111
112fn generic_decl(i: &str) -> R<'_, Option<TypeName>> {
114 let (i, _) = stag("<").parse(i)?;
115 let (i, _) = stag("?").parse(i)?;
116 let (i, name) = opt(preceded(stag("="), type_name)).parse(i)?;
117 let (i, _) = stag(">").parse(i)?;
118 Ok((i, name))
119}
120
121fn parameter_list(i: &str) -> R<'_, Vec<Parameter>> {
123 separated_list0(stag(","), parameter).parse(i)
124}
125
126fn result_list(i: &str) -> R<'_, Vec<Result_>> {
128 separated_list0(stag(","), result).parse(i)
129}
130
131fn parameter(i: &str) -> R<'_, Parameter> {
133 map(
134 pair(
135 separated_pair(identifier, stag(":"), cut(type_spec)),
136 opt(preceded(stag("="), literal_expr)),
137 ),
138 |((id, spec), lit)| Parameter { id, spec, lit, doc: None },
139 )
140 .parse(i)
141}
142
143fn result(i: &str) -> R<'_, Result_> {
145 map(separated_pair(identifier, stag(":"), cut(type_spec)), |(id, spec)| Result_ { id, spec })
146 .parse(i)
147}
148
149fn literal_expr(i: &str) -> R<'_, Literal> {
150 spaced(alt((
151 literal,
152 map(delimited(stag("["), separated_list0(stag(","), literal), stag("]")), Literal::Array),
153 map(delimited(stag("("), separated_list0(stag(","), literal), stag(")")), Literal::Tuple),
154 )))
155 .parse(i)
156}
157
158fn type_spec(i: &str) -> R<'_, TypeSpec> {
160 fn non_array_type(i: &str) -> R<'_, TypeSpec> {
161 alt((tuple_type_spec, map(type_name, TypeSpec::Single), tensor_type_spec)).parse(i)
162 }
163 alt((
164 (map(terminated(non_array_type, pair(stag("["), stag("]"))), |t| {
165 TypeSpec::Array(Box::new(t))
166 })),
167 non_array_type,
168 ))
169 .parse(i)
170}
171
172fn type_name(i: &str) -> R<'_, TypeName> {
174 spaced(alt((
175 map(tag("integer"), |_| TypeName::Integer),
176 map(tag("scalar"), |_| TypeName::Scalar),
177 map(tag("logical"), |_| TypeName::Logical),
178 map(tag("string"), |_| TypeName::String),
179 #[cfg(feature = "complex")]
180 map(tag("complex"), |_| TypeName::Complex),
181 map(tag("?"), |_| TypeName::Any),
182 )))
183 .parse(i)
184}
185
186fn tensor_type_spec(i: &str) -> R<'_, TypeSpec> {
188 map(delimited(pair(stag("tensor"), stag("<")), type_name, stag(">")), TypeSpec::Tensor).parse(i)
189}
190
191fn tuple_type_spec(i: &str) -> R<'_, TypeSpec> {
193 map(delimited(stag("("), separated_list0(stag(","), type_spec), stag(")")), TypeSpec::Tuple)
194 .parse(i)
195}
196
197fn graph_def(i: &str) -> R<'_, GraphDef> {
203 let (i, _) = stag("graph").parse(i)?;
204 let (i, id) = identifier(i)?;
205 let (i, _) = stag("(").parse(i)?;
206 let (i, parameters) = separated_list0(stag(","), identifier).parse(i)?;
207 let (i, _) = stag(")").parse(i)?;
208 let (i, _) = stag("->").parse(i)?;
209 let (i, _) = stag("(").parse(i)?;
210 let (i, results) = separated_list0(stag(","), identifier).parse(i)?;
211 let (i, _) = stag(")").parse(i)?;
212 let (i, body) = spaced(body).parse(i)?;
213 Ok((i, GraphDef { id, parameters, results, body }))
214}
215
216fn body(i: &str) -> R<'_, Vec<Assignment>> {
220 delimited(stag("{"), many0(assignment), stag("}")).parse(i)
221}
222
223fn assignment(i: &str) -> R<'_, Assignment> {
225 spaced(terminated(
226 map(separated_pair(lvalue, stag("="), rvalue), |(left, right)| Assignment { left, right }),
227 stag(";"),
228 ))
229 .parse(i)
230}
231
232fn lvalue(i: &str) -> R<'_, LValue> {
236 fn inner_lvalue(i: &str) -> R<'_, LValue> {
237 alt((
238 map(
239 delimited(stag("["), separated_list0(stag(","), inner_lvalue), stag("]")),
240 LValue::Array,
241 ),
242 map(
243 delimited(stag("("), separated_list0(stag(","), inner_lvalue), stag(")")),
244 LValue::Tuple,
245 ),
246 map(spaced(identifier), LValue::Identifier),
247 ))
248 .parse(i)
249 }
250
251 map(separated_list0(stag(","), inner_lvalue), |mut iv| {
252 if iv.len() == 1 {
253 iv.remove(0)
254 } else {
255 LValue::Tuple(iv)
256 }
257 })
258 .parse(i)
259}
260
261fn invocation(i: &str) -> R<'_, Invocation> {
263 let (i, id) = spaced(identifier).parse(i)?;
264 let (i, generic_type_name) = opt(delimited(stag("<"), type_name, stag(">"))).parse(i)?;
265 let (i, _) = stag("(").parse(i)?;
266 let (i, arguments) = argument_list.parse(i)?;
267 let (i, _) = stag(")").parse(i)?;
268 Ok((i, Invocation { id, generic_type_name, arguments }))
269}
270
271fn argument_list(i: &str) -> R<'_, Vec<Argument>> {
273 separated_list0(stag(","), argument).parse(i)
274}
275
276fn argument(i: &str) -> R<'_, Argument> {
278 spaced(map(pair(opt(terminated(identifier, stag("="))), rvalue), |(id, rvalue)| Argument {
279 id,
280 rvalue,
281 }))
282 .parse(i)
283}
284
285fn rvalue(i: &str) -> R<'_, RValue> {
289 fn atom(i: &str) -> R<'_, RValue> {
290 spaced(alt((
291 map(invocation, RValue::Invocation),
292 map(literal, RValue::Literal),
293 map(identifier, RValue::Identifier),
294 map(pair(spaced(recognize(one_of("+-!"))), rvalue), |(op, rv)| {
295 RValue::Unary(op.into(), Box::new(rv))
296 }),
297 map(delimited(tag("("), separated_list0(stag(","), rvalue), tag(")")), |mut rvs| {
298 if rvs.len() == 1 {
299 rvs.remove(0)
300 } else {
301 RValue::Tuple(rvs)
302 }
303 }),
304 map(comprehension_expr, |c| RValue::Comprehension(Box::new(c))),
305 map(delimited(tag("["), separated_list0(stag(","), rvalue), tag("]")), |rvs| {
306 RValue::Array(rvs)
307 }),
308 )))
309 .parse(i)
310 }
311 macro_rules! bin {
312 ($name:ident, $operand: ident, $operator: expr) => {
313 fn $name(i: &str) -> R<'_, RValue> {
314 let (i, init) = $operand(i)?;
315 fold_many0(
316 pair($operator, $operand),
317 move || init.clone(),
318 |left, (op, right)| {
319 RValue::Binary(Box::new(left), op.to_string(), Box::new(right))
320 },
321 )
322 .parse(i)
323 }
324 };
325 }
326
327 fn sub(i: &str) -> R<'_, RValue> {
329 alt((
330 map(
331 pair(
332 atom,
333 delimited(
334 stag("["),
335 alt((
336 map(separated_pair(opt(rvalue), stag(":"), opt(rvalue)), |(a, b)| {
337 Subscript::Range(a, b)
338 }),
339 map(rvalue, Subscript::Single),
340 )),
341 stag("]"),
342 ),
343 ),
344 |(rv, range)| RValue::Subscript(Box::new(rv), Box::new(range)),
345 ),
346 atom,
347 ))
348 .parse(i)
349 }
350
351 bin!(exp, sub, tag("^"));
352 bin!(mul, exp, one_of("*/"));
353 bin!(add, mul, one_of("+-"));
354 bin!(comp, add, alt((tag("=="), tag("!="), tag("<"), tag(">"), tag("<="), tag(">="))));
355 bin!(boolean, comp, alt((tag("||"), tag("&&"))));
356 bin!(in_for, boolean, tag("in"));
357
358 fn ite(i: &str) -> R<'_, RValue> {
360 let (i, leftmost) = in_for(i)?;
361 let (i, _) = space_and_comments(i)?;
362 if i.starts_with("if") {
363 let (i, _) = stag("if").parse(i)?;
364 let (i, cond) = in_for(i)?;
365 let (i, _) = stag("else").parse(i)?;
366 let (i, otherwise) = in_for(i)?;
367 Ok((i, RValue::IfThenElse(Box::new(IfThenElse { cond, then: leftmost, otherwise }))))
368 } else {
369 Ok((i, leftmost))
370 }
371 }
372
373 ite(i)
374}
375
376fn comprehension_expr(i: &str) -> R<'_, Comprehension> {
378 delimited(
379 pair(stag("["), stag("for")),
380 map(separated_pair(loop_iters, stag("yield"), rvalue), |(loop_iters, yields)| {
381 Comprehension { loop_iters, filter: None, yields }
382 }),
383 stag("]"),
384 )
385 .parse(i)
386}
387
388fn loop_iters(i: &str) -> R<'_, Vec<(Identifier, RValue)>> {
391 separated_list0(stag(","), separated_pair(identifier, stag("in"), rvalue)).parse(i)
392}
393
394pub(super) fn identifier(i: &str) -> R<'_, Identifier> {
399 alt((escaped_identifier, direct_identifier)).parse(i)
400}
401
402pub(super) fn direct_identifier(i: &str) -> R<'_, Identifier> {
403 map(
404 recognize(pair(alt((alpha1, tag("_"))), many0(alt((alphanumeric1, tag("_")))))),
405 Identifier::from,
406 )
407 .parse(i)
408}
409
410pub(super) fn escaped_identifier(i: &str) -> R<'_, Identifier> {
411 map(preceded(tag("i"), string_literal), Identifier).parse(i)
412}
413
414fn literal(i: &str) -> R<'_, Literal> {
416 spaced(alt((
417 map(numeric_literal, Literal::Numeric),
418 map(string_literal, Literal::String),
419 map(logical_literal, Literal::Logical),
420 )))
421 .parse(i)
422}
423
424pub(super) fn numeric_literal(i: &str) -> R<'_, String> {
425 fn exp_part(i: &str) -> R<'_, &str> {
426 recognize((one_of("eE"), opt(tag("-")), digit1)).parse(i)
427 }
428 fn frac_part(i: &str) -> R<'_, &str> {
429 recognize((tag("."), digit0)).parse(i)
430 }
431 spaced(map(
432 recognize((opt(tag("-")), alt((digit1, tag("inf"))), opt(frac_part), opt(exp_part))),
433 |s: &str| s.to_owned(),
434 ))
435 .parse(i)
436}
437
438fn string_literal(i: &str) -> R<'_, String> {
439 fn inner(i: &str) -> R<'_, String> {
440 map(
441 many0(alt((
442 preceded(tag("\\"), nom::character::complete::anychar),
443 nom::character::complete::none_of("\\\"'"),
444 ))),
445 |v: Vec<char>| v.into_iter().collect(),
446 )
447 .parse(i)
448 }
449 map(alt((delimited(tag("'"), inner, tag("'")), delimited(tag("\""), inner, tag("\"")))), |s| s)
450 .parse(i)
451}
452
453pub(super) fn logical_literal(i: &str) -> R<'_, bool> {
454 spaced(alt((map(tag("true"), |_| true), map(tag("false"), |_| false)))).parse(i)
455}
456
457fn space_and_comments(i: &str) -> R<'_, ()> {
460 map(
461 many0(alt((recognize(one_of(" \t\n\r")), recognize((tag("#"), many0(none_of("\r\n"))))))),
462 |_| (),
463 )
464 .parse(i)
465}
466
467fn spaced<'s, O, F>(it: F) -> impl Parser<&'s str, Output = O, Error = VerboseError<&'s str>>
468where
469 F: Parser<&'s str, Output = O, Error = VerboseError<&'s str>>,
470{
471 delimited(space_and_comments, it, space_and_comments)
472}
473
474pub(super) fn stag<'s>(
475 t: &'static str,
476) -> impl Parser<&'s str, Output = &'s str, Error = VerboseError<&'s str>> {
477 spaced(tag(t))
478}
479
#[cfg(test)]
mod test {
    use super::*;
    use TypeName::*;
    use TypeSpec::*;

    /// Runs `parser` over all of `i` and returns its output, panicking if the
    /// parser fails or leaves input unconsumed.
    fn p<'s, P, O, E>(parser: P, i: &'s str) -> O
    where
        O: std::fmt::Debug,
        P: Fn(&'s str) -> IResult<&'s str, O, E>,
        E: nom::error::ParseError<&'s str> + std::fmt::Debug,
    {
        let res = all_consuming(parser).parse(i).unwrap();
        res.1
    }

    /// Builds the expected `Parameter` for a name and type, with no default
    /// literal and no doc.
    fn param(s: impl Into<std::string::String>, t: TypeSpec) -> Parameter {
        Parameter { id: Identifier(s.into()), spec: t, lit: None, doc: None }
    }

    /// Builds the expected `Result_` for a name and type.
    fn result(s: impl Into<std::string::String>, t: TypeSpec) -> Result_ {
        Result_ { id: Identifier(s.into()), spec: t }
    }

    #[test]
    fn test_type_spec() {
        assert_eq!(p(type_spec, "scalar"), Single(Scalar));
        assert_eq!(p(type_spec, "scalar[]"), Array(Box::new(Single(Scalar))));
        assert_eq!(p(type_spec, "tensor<scalar>[]"), Array(Box::new(Tensor(TypeName::Scalar))));
        assert_eq!(
            p(type_spec, "(scalar,scalar[],tensor<scalar>)"),
            Tuple(vec!(Single(Scalar), Array(Box::new(Single(Scalar))), Tensor(Scalar)))
        );
        assert_eq!(p(type_spec, "tensor<?>[]"), Array(Box::new(Tensor(TypeName::Any))));
        assert_eq!(p(type_spec, "scalar[ ]"), Array(Box::new(Single(Scalar))));
        assert_eq!(
            p(type_spec, " ( scalar , scalar [ ] , tensor < scalar > ) "),
            Tuple(vec!(Single(Scalar), Array(Box::new(Single(Scalar))), Tensor(Scalar)))
        );
        #[cfg(feature = "complex")]
        assert_eq!(p(type_spec, "tensor<complex>[]"), Array(Box::new(Tensor(TypeName::Complex))));
    }

    #[test]
    fn test_fragment_decl_fizz() {
        let parsed = p(
            fragment_decl,
            "fragment fizz<? = scalar>( shape: integer[] ) -> ( output: tensor<?> )",
        );
        assert_eq!(
            parsed,
            FragmentDecl {
                id: "fizz".into(),
                generic_decl: Some(Some(Scalar)),
                parameters: vec!(param("shape", Array(Box::new(Single(Integer)))),),
                results: vec!(result("output", Tensor(Any))),
            }
        );
    }

    #[test]
    fn test_fragment_decl_logarithmic_quantize() {
        let parsed = p(fragment_decl,
            "fragment logarithmic_quantize(x: tensor<scalar>, max: tensor<scalar>, bits: integer ) -> ( y: tensor<scalar> )"
        );
        assert_eq!(
            parsed,
            FragmentDecl {
                id: "logarithmic_quantize".into(),
                generic_decl: None,
                parameters: vec!(
                    param("x", Tensor(Scalar)),
                    param("max", Tensor(Scalar)),
                    param("bits", Single(Integer))
                ),
                results: vec!(result("y", Tensor(Scalar))),
            }
        );
    }

    #[test]
    fn test_fragment_decl_external() {
        p(
            fragment_decl,
            "fragment external<? = scalar>( shape: integer[] ) -> ( output: tensor<?> )",
        );
    }

    #[test]
    fn test_fragment_reshape() {
        p(fragments, "fragment reshape<?>( input: tensor<?>, shape: integer[], axis_start: integer = 0, axis_count: integer = -1 ) -> ( output: tensor<?> );");
    }

    #[test]
    fn test_fragment_conv() {
        p(
            fragments,
            r#"
            fragment conv(
                input: tensor<scalar>,
                filter: tensor<scalar>,
                bias: tensor<scalar> = 0.0,
                border: string = 'constant',
                padding: (integer,integer)[] = [],
                stride: integer[] = [],
                dilation: integer[] = [],
                groups: integer = 1 )
            -> ( output: tensor<scalar> );
            "#,
        );
    }

    #[test]
    fn test_fragment_local_response_normalization() {
        p(
            fragments,
            r#"
            fragment local_response_normalization(
                input: tensor<scalar>,
                size: integer[],
                alpha: scalar = 1.0,
                beta: scalar = 0.5,
                bias: scalar = 1.0 )
            -> ( output: tensor<scalar> )
            {
                sigma = bias + alpha * box(sqr(input), size = size, normalize = true);
                output = input / (sigma ^ beta);
            }
            "#,
        );
    }

    #[test]
    fn test_batch_normalization() {
        p(
            fragments,
            r#"
            fragment batch_normalization( input: tensor<scalar>, mean: tensor<scalar>, variance: tensor<scalar>, offset: tensor<scalar>, scale: tensor<scalar>, epsilon: scalar )
            -> ( output: tensor<scalar> )
            {
                output = offset + scale * (input - mean) / sqrt(variance + epsilon);
            }
            "#,
        );
    }

    #[test]
    fn test_avg_roi_align() {
        p(
            fragments,
            r#"
            fragment avg_roi_align(
                input: tensor<scalar>,
                rois: tensor<scalar>,
                batch_index: tensor<integer>,
                output_size: integer[],
                sampling_rate: integer[],
                resize_method: string = 'symmetric' )
            -> ( output: tensor<scalar> )
            {
                size = [for i in range_of(output_size) yield output_size[i] * sampling_rate[i]];
                resized = roi_resample(input, rois, batch_index, output_size = size,
                method = resize_method);
                output = avg_pool(resized, size = sampling_rate, stride = sampling_rate);
            }
            "#,
        );
    }

    #[test]
    fn test_min_max_linear_quantize() {
        p(
            fragments,
            r#"
            fragment min_max_linear_quantize(
                x: tensor<scalar>,
                min: tensor<scalar>,
                max: tensor<scalar>,
                bits: integer,
                signed: logical,
                symmetric: logical )
            -> ( y: tensor<scalar> )
            {
                r = scalar(2 ^ bits - 1 - integer(signed && symmetric));
                z = clamp(x, min, max);
                p = scalar(2 ^ (bits - 1) - integer(symmetric) if signed else 0);
                q = round((z - min) / (max - min) * r) - p;
                y = (q + p) / r * (max - min) + min;
}
            "#,
        );
    }

    #[test]
    fn test_numeric() {
        // Check the returned text, not just that parsing succeeds.
        assert_eq!(p(numeric_literal, "12.0"), "12.0");
        assert_eq!(p(numeric_literal, "-1"), "-1");
        assert_eq!(p(numeric_literal, "1e-3"), "1e-3");
        assert_eq!(p(numeric_literal, "inf"), "inf");
        // Surrounding whitespace is skipped and not part of the literal.
        assert_eq!(p(numeric_literal, " 12 "), "12");
    }

    #[test]
    fn test_string() {
        assert_eq!(p(string_literal, r#""""#), "");
        assert_eq!(p(string_literal, r#""foo""#), "foo");
        assert_eq!(p(string_literal, r#"''"#), "");
        assert_eq!(p(string_literal, r#"'foo'"#), "foo");

        assert_eq!(p(string_literal, r"'f\oo'"), "foo");
        assert_eq!(p(string_literal, r"'f\'oo'"), "f'oo");
        assert_eq!(p(string_literal, r#"'f\"oo'"#), "f\"oo");
    }

    #[test]
    fn test_identifier() {
        p(identifier, "foo");
        assert!(identifier("1").is_err());
        assert!(identifier("1foo").is_err());
    }

    #[test]
    fn test_spacing() {
        p(space_and_comments, "");
        p(space_and_comments, "\n");
        p(space_and_comments, "#comment\n");
        p(space_and_comments, "#boum");
    }

    #[test]
    fn test_spaced() {
        assert!(spaced(identifier).parse("foo").is_ok());
        assert!(spaced(identifier).parse(" foo ").is_ok());
        assert!(many1(spaced(identifier)).parse(" foo bar ").is_ok());
        assert_eq!(
            many1(spaced(identifier)).parse(" foo bar\n").unwrap().1,
            &[Identifier("foo".to_string()), Identifier("bar".to_string())]
        );
        assert_eq!(
            many1(spaced(identifier)).parse(" foo # bar\n").unwrap().1,
            &[Identifier("foo".to_string())]
        );
        assert_eq!(
            many1(spaced(identifier)).parse(" foo # bar\nbaz").unwrap().1,
            &[Identifier("foo".to_string()), Identifier("baz".to_string())]
        );
    }

    #[test]
    fn test_document() {
        // Go through `p` so that leftover, unparsed input also fails the test.
        p(document, "version 1.0; graph foo() -> () {}");
    }

    #[test]
    fn test_version() {
        p(version, "version 1.0;");
    }

    #[test]
    fn test_body() {
        p(body, "{}");
        p(body, "{foo=bar;}");
    }

    #[test]
    fn test_lvalue() {
        p(lvalue, "foo");
        p(lvalue, "foo,bar");
        p(lvalue, "foo , bar");
        p(lvalue, "(foo,bar)");
    }

    #[test]
    fn test_graph_def() {
        p(graph_def, "graph foo() -> () {}");
    }

    #[test]
    fn test_assignment() {
        p(assignment, "input = external(12);");
        p(assignment, "input = external(shape = [1, 3, 224, 224]);");
        p(assignment, "sigma = bias + alpha * box(sqr(input), size = size, normalize = true);");
        p(assignment, "output = offset + scale * (input - mean) / sqrt(variance + epsilon);");
        p(
            assignment,
            "size = [for i in range_of(output_size) yield output_size[i] * sampling_rate[i]];",
        );
        p(assignment, "r = scalar(2 ^ bits - 1 - integer(signed && symmetric));");
        p(assignment, "output, index = max_pool_with_index(input, size = size, border = border, padding = padding, stride = stride, dilation = dilation);");
    }

    #[test]
    fn test_invocation() {
        p(invocation, "external(12)");
        p(invocation, "sqrt(var + eps)");
    }

    #[test]
    fn test_arguments() {
        p(argument, "2");
        p(argument, "12");
        p(argument, "shape = [1, 3, 224, 224]");
    }

    #[test]
    fn test_rvalue() {
        p(rvalue, "12");
        p(rvalue, "(0, 0)");
        p(rvalue, "x ^ 2.0");
        p(rvalue, "1+2");
        p(rvalue, "1+sqrt(var)");
        p(rvalue, "1+sqrt(var+eps)");
        p(rvalue, "1 + sqrt(var + eps)");
        p(rvalue, "[for i in range_of(output_size) yield output_size[i] * sampling_rate[i]]");
        p(rvalue, "scalar(2 ^ (bits - 1) - integer(symmetric) if signed else 0)");
    }

    #[test]
    fn test_comprehension() {
        p(comprehension_expr, "[for i in range_of(output_size) yield output_size * sampling_rate]");
    }

    #[test]
    fn test_freeze() {
        p(
            document,
            r#"
version 1.0;

graph y( x, s, bias ) -> ( y ) {
  x = external<scalar>(shape = [1, 2, 1, 3]);
  s = external<scalar>(shape = [2]);
  bias = external<scalar>(shape = [2]);
  y = add(
        mul(
            mul(
                sub(
                    x,
                    mul(
                        0.33333334,
                        sum_reduce(
                            x,
                            axes = [0, 2, 3]
                        )
                    )
                ),
                rsqrt(
                    add(
                        0.00001,
                        mul(
                            0.33333334,
                            sum_reduce(
                                square(
                                    sub(
                                        x,
                                        mul(
                                            0.33333334,
                                            sum_reduce(
                                                x,
                                                axes = [0, 2, 3]
                                            )
                                        )
                                    )
                                ),
                                axes = [0, 2, 3]
                            )
                        )
                    )
                )
            ),
            unsqueeze(
                unsqueeze(
                    unsqueeze(
                        s,
                        axes = [0]
                    ),
                    axes = [2]
                ),
                axes = [2]
            )
        ),
        unsqueeze(
            unsqueeze(
                unsqueeze(
                    bias,
                    axes = [0]
                ),
                axes = [2]
            ),
            axes = [2]
        )
    );
}

"#,
        );
    }

    #[test]
    fn test_fragments() {
        p(
            fragments,
            r#"
            fragment add( x: tensor<scalar>, y: tensor<scalar> ) -> ( z: tensor<scalar> );
            fragment sub( x: tensor<scalar>, y: tensor<scalar> ) -> ( z: tensor<scalar> );
            "#,
        );
    }
}