1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
/// Expands to an `op_families` method reporting that the op belongs to the
/// single family `"core"`.
///
/// Meant to be invoked inside an `impl` block (the generated method takes
/// `&self`).
#[macro_export]
macro_rules! op_hir {
    () => {
        fn op_families(&self) -> &'static [&'static str] {
            const FAMILIES: &[&str] = &["core"];
            FAMILIES
        }
    };
}

/// Expands to a `to_typed` method that wires a clone of `self` into `target`
/// under the inference node's name, feeding it the node's inputs translated
/// through `mapping`.
///
/// `_source` is accepted but not used. The expansion names `TypedModel`,
/// `OutletId`, `TractResult` and `TVec` unqualified, so those must be in scope
/// wherever the macro is invoked. Indexing `mapping` panics if one of
/// `node.inputs` has no entry in it.
#[macro_export]
macro_rules! to_typed {
    () => {
        fn to_typed(
            &self,
            _source: &$crate::infer::InferenceModel,
            node: &$crate::infer::InferenceNode,
            target: &mut TypedModel,
            mapping: &std::collections::HashMap<OutletId, OutletId>,
        ) -> TractResult<TVec<OutletId>> {
            // Translate every inference-side outlet into its typed-model counterpart.
            let wired_inputs: TVec<_> =
                node.inputs.iter().map(|outlet| mapping[outlet]).collect();
            let name = &*node.name;
            target.wire_node(name, self.clone(), &*wired_inputs)
        }
    };
}

/// Constructs a type factoid.
///
/// * `typefact!(_)` yields `TypeFactoid::default()`.
/// * `typefact!(value)` yields `GenericFactoid::Only(value)`, type-checked
///   against `TypeFactoid`.
#[macro_export]
macro_rules! typefact {
    (_) => (
        $crate::infer::TypeFactoid::default()
    );
    ($value:expr) => ({
        let factoid: $crate::infer::TypeFactoid = $crate::infer::GenericFactoid::Only($value);
        factoid
    });
}

/// Constructs a shape factoid.
///
/// Accepted forms:
/// * `shapefactoid![]` — `ShapeFactoid::closed` with no dimensions.
/// * `shapefactoid![..]` — `ShapeFactoid::open` with no dimensions.
/// * `shapefactoid![d0, d1, ..]`-style lists without a trailing `; ..` —
///   `ShapeFactoid::closed` over the listed dimensions.
/// * `shapefactoid![d0, d1; ..]` — `ShapeFactoid::open` over the listed
///   dimensions.
///
/// Each dimension is forwarded to `dimfact!` and must be a single token tree,
/// so wrap compound expressions in parentheses, e.g. `shapefactoid![(n + 1), _]`.
/// The expansion calls `tvec!` unqualified, so it must be in scope at the call site.
#[macro_export]
macro_rules! shapefactoid {
    () => {
        $crate::infer::ShapeFactoid::closed(tvec![])
    };
    (..) => {
        $crate::infer::ShapeFactoid::open(tvec![])
    };
    ($($dim:tt),+; ..) => {
        $crate::infer::ShapeFactoid::open(tvec![$($crate::dimfact!($dim)),+])
    };
    ($($dim:tt),+) => {
        $crate::infer::ShapeFactoid::closed(tvec![$($crate::dimfact!($dim)),+])
    };
}

/// Constructs a dimension factoid.
///
/// * `dimfact!(_)` — `DimFact::default()`.
/// * `dimfact!(S)` — `GenericFactoid::Only(tract_pulse::internal::stream_dim())`.
///   NOTE(review): this path is not `$crate`-relative, so the invoking crate
///   must itself be able to name `tract_pulse`.
/// * `dimfact!(value)` — `GenericFactoid::Only(value.to_dim())`; a `to_dim`
///   method must be in scope at the call site.
#[macro_export]
macro_rules! dimfact {
    (_) => ($crate::infer::DimFact::default());
    (S) => ($crate::infer::GenericFactoid::Only(tract_pulse::internal::stream_dim()));
    ($value:expr) => ($crate::infer::GenericFactoid::Only($value.to_dim()));
}

/// Constructs a value factoid.
///
/// * `valuefact!(_)` yields `ValueFact::default()`.
/// * `valuefact!(value)` yields `GenericFactoid::Only(value)`, type-checked
///   against `ValueFact`.
#[macro_export]
macro_rules! valuefact {
    (_) => (
        $crate::infer::ValueFact::default()
    );
    ($value:expr) => ({
        let factoid: $crate::infer::ValueFact = $crate::infer::GenericFactoid::Only($value);
        factoid
    });
}

/// Unwraps an `Option`, or makes the enclosing function return `Ok(None)`
/// when it is `None`.
///
/// The argument is evaluated exactly once, and any temporaries it creates are
/// dropped before the match (it is bound with `let` first). Because the `None`
/// branch uses `return`, the macro can only be used in a function or closure
/// whose return type is a `Result<Option<_>, _>`.
///
/// Variant paths are fully qualified so the expansion does not depend on what
/// `Some`/`None`/`Ok` resolve to at the call site.
#[macro_export]
macro_rules! unwrap_or_none {
    ($e:expr) => {{
        let option = $e;
        match option {
            ::core::option::Option::Some(value) => value,
            ::core::option::Option::None => {
                return ::core::result::Result::Ok(::core::option::Option::None);
            }
        }
    }};
}

// Unit tests for the factoid-construction macros in this file.
// Gated on `cfg(test)` (singular): that is the flag `cargo test` sets.
#[cfg(test)]
mod tests {
    // The macros expand to `$crate::infer::*` paths but call `tvec!` and
    // `.to_dim()` unqualified, so both must be in scope here.
    // NOTE(review): assumes `crate::internal` re-exports `tvec!` and `ToDim`.
    use crate::infer::{GenericFactoid, ShapeFactoid};
    use crate::internal::*;

    #[test]
    fn shape_macro_closed_1() {
        assert_eq!(shapefactoid![], ShapeFactoid::closed(tvec![]));
    }

    #[test]
    fn shape_macro_closed_2() {
        assert_eq!(
            shapefactoid![1],
            ShapeFactoid::closed(tvec![GenericFactoid::Only(1.to_dim())])
        );
    }

    #[test]
    fn shape_macro_closed_3() {
        assert_eq!(
            shapefactoid![(1 + 1)],
            ShapeFactoid::closed(tvec![GenericFactoid::Only(2.to_dim())])
        );
    }

    #[test]
    fn shape_macro_closed_4() {
        assert_eq!(
            shapefactoid![_, 2],
            ShapeFactoid::closed(tvec![GenericFactoid::Any, GenericFactoid::Only(2.to_dim())])
        );
    }

    #[test]
    fn shape_macro_closed_5() {
        assert_eq!(
            shapefactoid![(1 + 1), _, 2],
            ShapeFactoid::closed(tvec![
                GenericFactoid::Only(2.to_dim()),
                GenericFactoid::Any,
                GenericFactoid::Only(2.to_dim()),
            ])
        );
    }

    #[test]
    fn shape_macro_open_1() {
        assert_eq!(shapefactoid![..], ShapeFactoid::open(tvec![]));
    }

    #[test]
    fn shape_macro_open_2() {
        assert_eq!(
            shapefactoid![1; ..],
            ShapeFactoid::open(tvec![GenericFactoid::Only(1.to_dim())])
        );
    }

    #[test]
    fn shape_macro_open_3() {
        assert_eq!(
            shapefactoid![(1 + 1); ..],
            ShapeFactoid::open(tvec![GenericFactoid::Only(2.to_dim())])
        );
    }

    #[test]
    fn shape_macro_open_4() {
        assert_eq!(
            shapefactoid![_, 2; ..],
            ShapeFactoid::open(tvec![GenericFactoid::Any, GenericFactoid::Only(2.to_dim())])
        );
    }

    #[test]
    fn shape_macro_open_5() {
        assert_eq!(
            shapefactoid![(1 + 1), _, 2; ..],
            ShapeFactoid::open(tvec![
                GenericFactoid::Only(2.to_dim()),
                GenericFactoid::Any,
                GenericFactoid::Only(2.to_dim()),
            ])
        );
    }
}