use tract_core::internal::*;
use nom::branch::alt;
use nom::combinator::map;
use nom::IResult;
use nom::{bytes::complete::*, character::complete::*, combinator::*, multi::*, sequence::*};
use crate::ast::*;
/// Converts any debuggable parser error into a `TractError`, embedding the
/// error's `Debug` rendering in the message.
pub(super) fn translate_error<E>(e: E) -> TractError
where
    E: std::fmt::Debug,
{
    format_err!("Fail to parse NNEF document: {:?}", e)
}
#[inline(never)]
pub fn parse_document(doc: &str) -> TractResult<Document> {
    all_consuming(document)(doc).map(|pair| pair.1).map_err(translate_error)
}
#[inline(never)]
pub fn parse_fragments(doc: &str) -> TractResult<Vec<FragmentDef>> {
    all_consuming(fragments)(doc).map(|pair| pair.1).map_err(translate_error)
}
#[inline(never)]
pub fn parse_fragment_decl(doc: &str) -> TractResult<FragmentDecl> {
    all_consuming(fragment_decl)(doc).map(|pair| pair.1).map_err(translate_error)
}
#[inline(never)]
pub fn parse_parameters(doc: &str) -> TractResult<Vec<Parameter>> {
    all_consuming(parameter_list)(doc).map(|pair| pair.1).map_err(translate_error)
}
fn document(i: &str) -> IResult<&str, Document> {
    map(
        tuple((version, many0(extension), fragments, graph_def)),
        |(version, extension, fragments, graph_def)| Document {
            version,
            extension,
            fragments,
            graph_def,
        },
    )(i)
}
/// Parses zero or more consecutive fragment definitions.
fn fragments(i: &str) -> IResult<&str, Vec<FragmentDef>> {
    let mut all_fragments = many0(fragment_def);
    all_fragments(i)
}
fn version(i: &str) -> IResult<&str, NumericLiteral> {
    delimited(stag("version"), numeric_literal, stag(";"))(i)
}
fn extension(i: &str) -> IResult<&str, (Identifier, String)> {
    delimited(
        stag("extension"),
        pair(spaced(identifier), map(take_until(";"), |s: &str| s.to_string())),
        stag(";"),
    )(i)
}
fn fragment_def(i: &str) -> IResult<&str, FragmentDef> {
    spaced(map(
        pair(fragment_decl, alt((map(body, Some), map(stag(";"), |_| None)))),
        |(decl, body)| FragmentDef { decl, body },
    ))(i)
}
fn fragment_decl(i: &str) -> IResult<&str, FragmentDecl> {
    let (i, _) = stag("fragment")(i)?;
    let (i, id) = identifier(i)?;
    let (i, generic_decl) = opt(generic_decl)(i)?;
    let (i, _) = stag("(")(i)?;
    let (i, parameters) = parameter_list(i)?;
    let (i, _) = stag(")")(i)?;
    let (i, _) = stag("->")(i)?;
    let (i, _) = stag("(")(i)?;
    let (i, results) = result_list(i)?;
    let (i, _) = stag(")")(i)?;
    Ok((i, FragmentDecl { id, parameters, results, generic_decl }))
}
fn generic_decl(i: &str) -> IResult<&str, Option<TypeName>> {
    let (i, _) = stag("<")(i)?;
    let (i, _) = stag("?")(i)?;
    let (i, name) = opt(preceded(stag("="), type_name))(i)?;
    let (i, _) = stag(">")(i)?;
    Ok((i, name))
}
/// Parses a possibly empty, comma-separated list of parameters.
fn parameter_list(i: &str) -> IResult<&str, Vec<Parameter>> {
    let comma = stag(",");
    separated_list0(comma, parameter)(i)
}
/// Parses a possibly empty, comma-separated list of results.
fn result_list(i: &str) -> IResult<&str, Vec<Result_>> {
    let comma = stag(",");
    separated_list0(comma, result)(i)
}
fn parameter(i: &str) -> IResult<&str, Parameter> {
    map(
        pair(
            separated_pair(identifier, stag(":"), type_spec),
            opt(preceded(stag("="), literal_expr)),
        ),
        |((id, spec), lit)| Parameter { id, spec, lit, doc: None },
    )(i)
}
fn result(i: &str) -> IResult<&str, Result_> {
    map(separated_pair(identifier, stag(":"), type_spec), |(id, spec)| Result_ { id, spec })(i)
}
/// Parses a literal expression: a plain literal, a `[ ... ]` array of
/// literals, or a `( ... )` tuple of literals. Arrays and tuples hold plain
/// literals only (no nesting).
fn literal_expr(i: &str) -> IResult<&str, Literal> {
    fn items(i: &str) -> IResult<&str, Vec<Literal>> {
        separated_list0(stag(","), literal)(i)
    }
    fn array_literal(i: &str) -> IResult<&str, Literal> {
        map(delimited(stag("["), items, stag("]")), Literal::Array)(i)
    }
    fn tuple_literal(i: &str) -> IResult<&str, Literal> {
        map(delimited(stag("("), items, stag(")")), Literal::Tuple)(i)
    }
    spaced(alt((literal, array_literal, tuple_literal)))(i)
}
fn type_spec(i: &str) -> IResult<&str, TypeSpec> {
    fn non_array_type(i: &str) -> IResult<&str, TypeSpec> {
        alt((tuple_type_spec, map(type_name, TypeSpec::Single), tensor_type_spec))(i)
    }
    alt((
        (map(terminated(non_array_type, pair(stag("["), stag("]"))), |t| {
            TypeSpec::Array(Box::new(t))
        })),
        non_array_type,
    ))(i)
}
fn type_name(i: &str) -> IResult<&str, TypeName> {
    spaced(alt((
        map(tag("integer"), |_| TypeName::Integer),
        map(tag("scalar"), |_| TypeName::Scalar),
        map(tag("logical"), |_| TypeName::Logical),
        map(tag("string"), |_| TypeName::String),
        #[cfg(feature = "complex")]
        map(tag("complex"), |_| TypeName::Complex),
        map(tag("?"), |_| TypeName::Any),
    )))(i)
}
fn tensor_type_spec(i: &str) -> IResult<&str, TypeSpec> {
    map(delimited(pair(stag("tensor"), stag("<")), type_name, stag(">")), TypeSpec::Tensor)(i)
}
/// Parses a parenthesised, comma-separated (possibly empty) list of type
/// specifications into `TypeSpec::Tuple`.
fn tuple_type_spec(i: &str) -> IResult<&str, TypeSpec> {
    let (i, _) = stag("(")(i)?;
    let (i, members) = separated_list0(stag(","), type_spec)(i)?;
    let (i, _) = stag(")")(i)?;
    Ok((i, TypeSpec::Tuple(members)))
}
fn graph_def(i: &str) -> IResult<&str, GraphDef> {
    let (i, _) = stag("graph")(i)?;
    let (i, id) = identifier(i)?;
    let (i, _) = stag("(")(i)?;
    let (i, parameters) = separated_list0(stag(","), identifier)(i)?;
    let (i, _) = stag(")")(i)?;
    let (i, _) = stag("->")(i)?;
    let (i, _) = stag("(")(i)?;
    let (i, results) = separated_list0(stag(","), identifier)(i)?;
    let (i, _) = stag(")")(i)?;
    let (i, body) = spaced(body)(i)?;
    Ok((i, GraphDef { id, parameters, results, body }))
}
/// Parses a brace-delimited block of zero or more assignments.
fn body(i: &str) -> IResult<&str, Vec<Assignment>> {
    let (i, _) = stag("{")(i)?;
    let (i, assignments) = many0(assignment)(i)?;
    let (i, _) = stag("}")(i)?;
    Ok((i, assignments))
}
fn assignment(i: &str) -> IResult<&str, Assignment> {
    spaced(terminated(
        map(separated_pair(lvalue, stag("="), rvalue), |(left, right)| Assignment { left, right }),
        stag(";"),
    ))(i)
}
/// Parses the left-hand side of an assignment.
///
/// A bare comma-separated list (`a, b`) is turned into `LValue::Tuple`; a
/// list with exactly one element yields that element directly. Because the
/// list may be empty, an empty input yields an empty `LValue::Tuple`.
fn lvalue(i: &str) -> IResult<&str, LValue> {
    // One element: a bracketed array, a parenthesised tuple, or an identifier.
    fn inner_lvalue(i: &str) -> IResult<&str, LValue> {
        fn elements(i: &str) -> IResult<&str, Vec<LValue>> {
            separated_list0(stag(","), inner_lvalue)(i)
        }
        alt((
            map(delimited(stag("["), elements, stag("]")), LValue::Array),
            map(delimited(stag("("), elements, stag(")")), LValue::Tuple),
            map(spaced(identifier), LValue::Identifier),
        ))(i)
    }
    let (i, mut parts) = separated_list0(stag(","), inner_lvalue)(i)?;
    let value = match parts.len() {
        1 => parts.remove(0),
        _ => LValue::Tuple(parts),
    };
    Ok((i, value))
}
fn invocation(i: &str) -> IResult<&str, Invocation> {
    let (i, id) = spaced(identifier)(i)?;
    let (i, generic_type_name) = opt(delimited(stag("<"), type_name, stag(">")))(i)?;
    let (i, _) = stag("(")(i)?;
    let (i, arguments) = argument_list(i)?;
    let (i, _) = stag(")")(i)?;
    Ok((i, Invocation { id, generic_type_name, arguments }))
}
/// Parses a possibly empty, comma-separated list of invocation arguments.
fn argument_list(i: &str) -> IResult<&str, Vec<Argument>> {
    let comma = stag(",");
    separated_list0(comma, argument)(i)
}
fn argument(i: &str) -> IResult<&str, Argument> {
    spaced(map(pair(opt(terminated(identifier, stag("="))), rvalue), |(id, rvalue)| Argument {
        id,
        rvalue,
    }))(i)
}
fn rvalue(i: &str) -> IResult<&str, RValue> {
    fn atom(i: &str) -> IResult<&str, RValue> {
        spaced(alt((
            map(invocation, RValue::Invocation),
            map(literal, RValue::Literal),
            map(identifier, RValue::Identifier),
            map(pair(spaced(recognize(one_of("+-!"))), rvalue), |(op, rv)| {
                RValue::Unary(op.into(), Box::new(rv))
            }),
            map(delimited(tag("("), separated_list0(stag(","), rvalue), tag(")")), |mut rvs| {
                if rvs.len() == 1 {
                    rvs.remove(0)
                } else {
                    RValue::Tuple(rvs)
                }
            }),
            map(comprehension_expr, |c| RValue::Comprehension(Box::new(c))),
            map(delimited(tag("["), separated_list0(stag(","), rvalue), tag("]")), |rvs| {
                RValue::Array(rvs)
            }),
        )))(i)
    }
    macro_rules! bin {
        ($name:ident, $operand: ident, $operator: expr) => {
            fn $name(i: &str) -> IResult<&str, RValue> {
                let (i, init) = $operand(i)?;
                fold_many0(
                    pair($operator, $operand),
                    move || init.clone(),
                    |left, (op, right)| {
                        RValue::Binary(Box::new(left), op.to_string(), Box::new(right))
                    },
                )(i)
            }
        };
    }
    fn sub(i: &str) -> IResult<&str, RValue> {
        alt((
            map(
                pair(
                    atom,
                    delimited(
                        stag("["),
                        alt((
                            map(separated_pair(opt(rvalue), stag(":"), opt(rvalue)), |(a, b)| {
                                Subscript::Range(a, b)
                            }),
                            map(rvalue, Subscript::Single),
                        )),
                        stag("]"),
                    ),
                ),
                |(rv, range)| RValue::Subscript(Box::new(rv), Box::new(range)),
            ),
            atom,
        ))(i)
    }
    bin!(exp, sub, tag("^"));
    bin!(mul, exp, one_of("*/"));
    bin!(add, mul, one_of("+-"));
    bin!(comp, add, alt((tag("=="), tag("!="), tag("<"), tag(">"), tag("<="), tag(">="))));
    bin!(boolean, comp, alt((tag("||"), tag("&&"))));
    bin!(in_for, boolean, tag("in"));
    fn ite(i: &str) -> IResult<&str, RValue> {
        let (i, leftmost) = in_for(i)?;
        let (i, _) = space_and_comments(i)?;
        if i.starts_with("if") {
            let (i, _) = stag("if")(i)?;
            let (i, cond) = in_for(i)?;
            let (i, _) = stag("else")(i)?;
            let (i, otherwise) = in_for(i)?;
            Ok((i, RValue::IfThenElse(Box::new(IfThenElse { cond, then: leftmost, otherwise }))))
        } else {
            Ok((i, leftmost))
        }
    }
    ite(i)
}
fn comprehension_expr(i: &str) -> IResult<&str, Comprehension> {
    delimited(
        pair(stag("["), stag("for")),
        map(separated_pair(loop_iters, stag("yield"), rvalue), |(loop_iters, yields)| {
            Comprehension { loop_iters, filter: None, yields }
        }),
        stag("]"),
    )(i)
}
/// Parses a possibly empty, comma-separated list of `<id> in <rvalue>` loop
/// bindings.
fn loop_iters(i: &str) -> IResult<&str, Vec<(Identifier, RValue)>> {
    fn binding(i: &str) -> IResult<&str, (Identifier, RValue)> {
        separated_pair(identifier, stag("in"), rvalue)(i)
    }
    separated_list0(stag(","), binding)(i)
}
/// Parses an identifier, trying the escaped form (`i"..."`) before the
/// plain form.
pub(super) fn identifier(i: &str) -> IResult<&str, Identifier> {
    let mut either_form = alt((escaped_identifier, direct_identifier));
    either_form(i)
}
/// Parses a plain identifier: a letter or underscore followed by any number
/// of letters, digits or underscores.
pub(super) fn direct_identifier(i: &str) -> IResult<&str, Identifier> {
    let head = alt((alpha1, tag("_")));
    let tail = many0(alt((alphanumeric1, tag("_"))));
    map(recognize(pair(head, tail)), Identifier::from)(i)
}
pub(super) fn escaped_identifier(i: &str) -> IResult<&str, Identifier> {
    map(preceded(tag("i"), string_literal), Identifier)(i)
}
/// Parses a numeric, string or logical literal (tried in that order),
/// skipping surrounding whitespace and comments.
fn literal(i: &str) -> IResult<&str, Literal> {
    let numeric = map(numeric_literal, Literal::Numeric);
    let string = map(string_literal, Literal::String);
    let logical = map(logical_literal, Literal::Logical);
    spaced(alt((numeric, string, logical)))(i)
}
/// Parses a numeric literal and returns its text unchanged.
///
/// Accepted shape: optional `-`, then digits or `inf`, an optional
/// fractional part (`.` followed by zero or more digits) and an optional
/// exponent (`e`/`E`, optional sign, digits). Surrounding whitespace and
/// comments are skipped.
pub(super) fn numeric_literal(i: &str) -> IResult<&str, String> {
    // The exponent sign may be `+` as well as `-` (e.g. `1e+5`).
    fn exp_part(i: &str) -> IResult<&str, &str> {
        recognize(tuple((one_of("eE"), opt(one_of("+-")), digit1)))(i)
    }
    fn frac_part(i: &str) -> IResult<&str, &str> {
        recognize(tuple((tag("."), digit0)))(i)
    }
    spaced(map(
        recognize(tuple((opt(tag("-")), alt((digit1, tag("inf"))), opt(frac_part), opt(exp_part)))),
        |s: &str| s.to_owned(),
    ))(i)
}
/// Parses a string literal delimited by single or double quotes and returns
/// its content.
///
/// A backslash escapes the next character, which is kept as-is (the
/// backslash itself is dropped). Only the delimiter that opened the string
/// terminates it, so the other quote character may appear unescaped:
/// `"it's"` and `'say "hi"'` are both accepted.
fn string_literal(i: &str) -> IResult<&str, String> {
    // Content parser; `stop` lists the characters that cannot appear
    // unescaped: the backslash and the active delimiter.
    fn content<'s>(stop: &'static str) -> impl FnMut(&'s str) -> IResult<&'s str, String> {
        map(
            many0(alt((
                preceded(tag("\\"), nom::character::complete::anychar),
                nom::character::complete::none_of(stop),
            ))),
            |v: Vec<char>| v.into_iter().collect(),
        )
    }
    alt((
        delimited(tag("'"), content("\\'"), tag("'")),
        delimited(tag("\""), content("\\\""), tag("\"")),
    ))(i)
}
/// Parses the keywords `true` or `false`, skipping surrounding whitespace
/// and comments.
pub(super) fn logical_literal(i: &str) -> IResult<&str, bool> {
    let yes = map(tag("true"), |_| true);
    let no = map(tag("false"), |_| false);
    spaced(alt((yes, no)))(i)
}
fn space_and_comments(i: &str) -> IResult<&str, ()> {
    map(
        many0(alt((
            recognize(one_of(" \t\n\r")),
            recognize(tuple((tag("#"), many0(none_of("\r\n"))))),
        ))),
        |_| (),
    )(i)
}
fn spaced<'s, O, F>(it: F) -> impl FnMut(&'s str) -> IResult<&'s str, O>
where
    F: FnMut(&'s str) -> IResult<&'s str, O>,
{
    delimited(space_and_comments, it, space_and_comments)
}
pub(super) fn stag<'s>(t: &'static str) -> impl FnMut(&'s str) -> IResult<&'s str, &'s str> {
    spaced(tag(t))
}
// Unit tests for the individual grammar rules. Most tests only check that
// an input is accepted in full; a few compare the resulting AST.
#[cfg(test)]
mod test {
    use super::*;
    use TypeName::*;
    use TypeSpec::*;
    // Runs `parser` on `i`, requiring the whole input to be consumed, and
    // returns the parsed value. Panics (via `unwrap`) on any parse error.
    fn p<'s, P, O, E>(parser: P, i: &'s str) -> O
    where
        O: std::fmt::Debug,
        P: Fn(&'s str) -> IResult<&'s str, O, E>,
        E: nom::error::ParseError<&'s str> + std::fmt::Debug,
    {
        let res = all_consuming(parser)(i).unwrap();
        res.1
    }
    // Builds the expected `Parameter` for a parameter without default value.
    fn param(s: impl Into<std::string::String>, t: TypeSpec) -> Parameter {
        Parameter { id: Identifier(s.into()), spec: t, lit: None, doc: None }
    }
    // Builds the expected `Result_` for a result declaration.
    fn result(s: impl Into<std::string::String>, t: TypeSpec) -> Result_ {
        Result_ { id: Identifier(s.into()), spec: t }
    }
    // Type specifications: plain, array, tensor, tuple, with and without
    // extra whitespace.
    #[test]
    fn test_type_spec() {
        assert_eq!(p(type_spec, "scalar"), Single(Scalar));
        assert_eq!(p(type_spec, "scalar[]"), Array(Box::new(Single(Scalar))));
        assert_eq!(p(type_spec, "tensor<scalar>[]"), Array(Box::new(Tensor(TypeName::Scalar))));
        assert_eq!(
            p(type_spec, "(scalar,scalar[],tensor<scalar>)"),
            Tuple(vec!(Single(Scalar), Array(Box::new(Single(Scalar))), Tensor(Scalar)))
        );
        assert_eq!(p(type_spec, "tensor<?>[]"), Array(Box::new(Tensor(TypeName::Any))));
        assert_eq!(p(type_spec, "scalar[ ]"), Array(Box::new(Single(Scalar))));
        assert_eq!(
            p(type_spec, " ( scalar , scalar [ ] , tensor < scalar > ) "),
            Tuple(vec!(Single(Scalar), Array(Box::new(Single(Scalar))), Tensor(Scalar)))
        );
        #[cfg(feature = "complex")]
        assert_eq!(p(type_spec, "tensor<complex>[]"), Array(Box::new(Tensor(TypeName::Complex))));
    }
    // Generic fragment declaration with a default type.
    #[test]
    fn test_fragment_decl_fizz() {
        let parsed = p(
            fragment_decl,
            "fragment fizz<? = scalar>( shape: integer[] ) -> ( output: tensor<?> )",
        );
        assert_eq!(
            parsed,
            FragmentDecl {
                id: "fizz".into(),
                generic_decl: Some(Some(Scalar)),
                parameters: vec!(param("shape", Array(Box::new(Single(Integer)))),),
                results: vec!(result("output", Tensor(Any))),
            }
        );
    }
    // Non-generic fragment declaration with several parameters.
    #[test]
    fn test_fragment_decl_logarithmic_quantize() {
        let parsed = p(fragment_decl,
                           "fragment logarithmic_quantize(x: tensor<scalar>, max: tensor<scalar>, bits: integer ) -> ( y: tensor<scalar> )"
                          );
        assert_eq!(
            parsed,
            FragmentDecl {
                id: "logarithmic_quantize".into(),
                generic_decl: None,
                parameters: vec!(
                    param("x", Tensor(Scalar)),
                    param("max", Tensor(Scalar)),
                    param("bits", Single(Integer))
                ),
                results: vec!(result("y", Tensor(Scalar))),
            }
        );
    }
    #[test]
    fn test_fragment_decl_external() {
        p(
            fragment_decl,
            "fragment external<? = scalar>( shape: integer[] ) -> ( output: tensor<?> )",
        );
    }
    // Body-less fragment (terminated by `;`) with default parameter values.
    #[test]
    fn test_fragment_reshape() {
        p(fragments, "fragment reshape<?>( input: tensor<?>, shape: integer[], axis_start: integer = 0, axis_count: integer = -1 ) -> ( output: tensor<?> );");
    }
    // Defaults of various kinds: numeric, string and empty arrays.
    #[test]
    fn test_fragment_conv() {
        p(
            fragments,
            r#"
            fragment conv(
                input: tensor<scalar>,
                filter: tensor<scalar>,
                bias: tensor<scalar> = 0.0,
                border: string = 'constant',
                padding: (integer,integer)[] = [],
                stride: integer[] = [],
                dilation: integer[] = [],
                groups: integer = 1 )
            -> ( output: tensor<scalar> );
            "#,
        );
    }
    // Fragment with a body using arithmetic and named arguments.
    #[test]
    fn test_fragment_local_response_normalization() {
        p(
            fragments,
            r#"
            fragment local_response_normalization(
                input: tensor<scalar>,
                size: integer[],
                alpha: scalar = 1.0,
                beta: scalar = 0.5,
                bias: scalar = 1.0 )
            -> ( output: tensor<scalar> )
            {
                sigma = bias + alpha * box(sqr(input), size = size, normalize = true);
                output = input / (sigma ^ beta);
            }
            "#,
        );
    }
    #[test]
    fn test_batch_normalization() {
        p(
            fragments,
            r#"
            fragment batch_normalization( input: tensor<scalar>, mean: tensor<scalar>, variance: tensor<scalar>, offset: tensor<scalar>, scale: tensor<scalar>, epsilon: scalar )
            -> ( output: tensor<scalar> )
            {
                output = offset + scale * (input - mean) / sqrt(variance + epsilon);
            }
            "#,
        );
    }
    // Body containing a comprehension and an invocation spanning two lines.
    #[test]
    fn test_avg_roi_align() {
        p(
            fragments,
            r#"
                fragment avg_roi_align(
                    input: tensor<scalar>,
                    rois: tensor<scalar>,
                    batch_index: tensor<integer>,
                    output_size: integer[],
                    sampling_rate: integer[],
                    resize_method: string = 'symmetric' )
                -> ( output: tensor<scalar> )
                {
                    size = [for i in range_of(output_size) yield output_size[i] * sampling_rate[i]];
                    resized = roi_resample(input, rois, batch_index, output_size = size,
                                         method = resize_method);
                    output = avg_pool(resized, size = sampling_rate, stride = sampling_rate);
                }
            "#,
        );
    }
    // Body containing `&&` and an `if ... else ...` expression.
    #[test]
    fn test_min_max_linear_quantize() {
        p(
            fragments,
            r#"
                fragment min_max_linear_quantize(
                    x: tensor<scalar>,
                    min: tensor<scalar>,
                    max: tensor<scalar>,
                    bits: integer,
                    signed: logical,
                    symmetric: logical )
                -> ( y: tensor<scalar> )
                {
                    r = scalar(2 ^ bits - 1 - integer(signed && symmetric));
                    z = clamp(x, min, max);
                    p = scalar(2 ^ (bits - 1) - integer(symmetric) if signed else 0);
                    q = round((z - min) / (max - min) * r) - p;
                    y = (q + p) / r * (max - min) + min;
}
            "#,
        );
    }
    #[test]
    fn test_numeric() {
        p(numeric_literal, "12.0");
    }
    // Both quote styles, empty strings and backslash escapes.
    #[test]
    fn test_string() {
        assert_eq!(p(string_literal, r#""""#), "");
        assert_eq!(p(string_literal, r#""foo""#), "foo");
        assert_eq!(p(string_literal, r#"''"#), "");
        assert_eq!(p(string_literal, r#"'foo'"#), "foo");
        assert_eq!(p(string_literal, r"'f\oo'"), "foo");
        assert_eq!(p(string_literal, r"'f\'oo'"), "f'oo");
        assert_eq!(p(string_literal, r#"'f\"oo'"#), "f\"oo");
    }
    // Identifiers cannot start with a digit.
    #[test]
    fn test_identifier() {
        p(identifier, "foo");
        assert!(identifier("1").is_err());
        assert!(identifier("1foo").is_err());
    }
    // Whitespace skipping, including comments with and without a final newline.
    #[test]
    fn test_spacing() {
        p(space_and_comments, "");
        p(space_and_comments, "\n");
        p(space_and_comments, "#comment\n");
        p(space_and_comments, "#boum");
    }
    // `spaced` skips whitespace and comments around the wrapped parser.
    #[test]
    fn test_spaced() {
        assert!(spaced(identifier)("foo").is_ok());
        assert!(spaced(identifier)(" foo ").is_ok());
        assert!(many1(spaced(identifier))(" foo bar ").is_ok());
        assert_eq!(
            many1(spaced(identifier))(" foo bar\n").unwrap().1,
            &[Identifier("foo".to_string()), Identifier("bar".to_string())]
        );
        assert_eq!(
            many1(spaced(identifier))(" foo # bar\n").unwrap().1,
            &[Identifier("foo".to_string())]
        );
        assert_eq!(
            many1(spaced(identifier))(" foo # bar\nbaz").unwrap().1,
            &[Identifier("foo".to_string()), Identifier("baz".to_string())]
        );
    }
    // Minimal document: version statement plus an empty graph.
    #[test]
    fn test_document() {
        assert!(document("version 1.0; graph foo() -> () {}").is_ok());
    }
    #[test]
    fn test_version() {
        p(version, "version 1.0;");
    }
    #[test]
    fn test_body() {
        p(body, "{}");
        p(body, "{foo=bar;}");
    }
    // Single identifier, bare comma lists and a parenthesised tuple.
    #[test]
    fn test_lvalue() {
        p(lvalue, "foo");
        p(lvalue, "foo,bar");
        p(lvalue, "foo , bar");
        p(lvalue, "(foo,bar)");
    }
    #[test]
    fn test_graph_def() {
        p(graph_def, "graph foo() -> () {}");
    }
    // Assignment statements taken from the fragment fixtures above, plus a
    // multi-target assignment.
    #[test]
    fn test_assignment() {
        p(assignment, "input = external(12);");
        p(assignment, "input = external(shape = [1, 3, 224, 224]);");
        p(assignment, "sigma = bias + alpha * box(sqr(input), size = size, normalize = true);");
        p(assignment, "output = offset + scale * (input - mean) / sqrt(variance + epsilon);");
        p(
            assignment,
            "size = [for i in range_of(output_size) yield output_size[i] * sampling_rate[i]];",
        );
        p(assignment, "r = scalar(2 ^ bits - 1 - integer(signed && symmetric));");
        p(assignment, "output, index = max_pool_with_index(input, size = size, border = border, padding = padding, stride = stride, dilation = dilation);");
    }
    #[test]
    fn test_invocation() {
        p(invocation, "external(12)");
        p(invocation, "sqrt(var + eps)");
    }
    // Positional and named arguments.
    #[test]
    fn test_arguments() {
        p(argument, "2");
        p(argument, "12");
        p(argument, "shape = [1, 3, 224, 224]");
    }
    // Expressions: literals, tuples, operators with and without spaces,
    // comprehensions and the `if ... else ...` form.
    #[test]
    fn test_rvalue() {
        p(rvalue, "12");
        p(rvalue, "(0, 0)");
        p(rvalue, "x ^ 2.0");
        p(rvalue, "1+2");
        p(rvalue, "1+sqrt(var)");
        p(rvalue, "1+sqrt(var+eps)");
        p(rvalue, "1 + sqrt(var + eps)");
        p(rvalue, "[for i in range_of(output_size) yield output_size[i] * sampling_rate[i]]");
        p(rvalue, "scalar(2 ^ (bits - 1) - integer(symmetric) if signed else 0)");
    }
    #[test]
    fn test_comprehenion() {
        p(comprehension_expr, "[for i in range_of(output_size) yield output_size * sampling_rate]");
    }
    // Complete document with deeply nested invocations; the test name
    // suggests it guards against the parser hanging on such input.
    #[test]
    fn test_freeze() {
        p(
            document,
            r#"
version 1.0;
graph y( x, s, bias ) -> ( y ) {
  x = external<scalar>(shape = [1, 2, 1, 3]);
  s = external<scalar>(shape = [2]);
  bias = external<scalar>(shape = [2]);
  y = add(
        mul(
            mul(
                sub(
                    x,
                    mul(
                        0.33333334,
                        sum_reduce(
                            x,
                            axes = [0, 2, 3]
                        )
                    )
                ),
                rsqrt(
                    add(
                        0.00001,
                        mul(
                            0.33333334,
                            sum_reduce(
                                square(
                                    sub(
                                        x,
                                        mul(
                                            0.33333334,
                                            sum_reduce(
                                                x,
                                                axes = [0, 2, 3]
                                            )
                                        )
                                    )
                                ),
                                axes = [0, 2, 3]
                            )
                        )
                    )
                )
            ),
            unsqueeze(
                unsqueeze(
                    unsqueeze(
                        s,
                        axes = [0]
                    ),
                axes = [2]
                ),
            axes = [2]
            )
        ),
        unsqueeze(
            unsqueeze(
                unsqueeze(
                    bias,
                    axes = [0]
                ),
                axes = [2]
            ),
            axes = [2]
        )
    );
}
"#,
        );
    }
    // Several body-less fragments in a row.
    #[test]
    fn test_fragments() {
        p(
            fragments,
            r#"
            fragment add( x: tensor<scalar>, y: tensor<scalar> ) -> ( z: tensor<scalar> );
            fragment sub( x: tensor<scalar>, y: tensor<scalar> ) -> ( z: tensor<scalar> );
            "#,
        );
    }
}