//! TensorFlow Lite Op Resolvers
//!
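//! A minimal usage sketch, mirroring this module's tests (marked `ignore`
//! rather than run as a doc test):
//!
//! ```ignore
//! // Resolver with every built-in operator registered
//! let all_ops = AllOpResolver::new();
//! ```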

use crate::bindings::tflite;

use core::fmt;

cpp! {{
    #include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
    #include "tensorflow/lite/micro/kernels/all_ops_resolver.h"
}}

// AllOpsResolver has the same memory representation as
// MicroMutableOpResolver<128>.
//
// That is:
// class AllOpsResolver : public MicroMutableOpResolver<128> { ... }
//
// Thus we can cast between the two types.
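//
// A hypothetical spot-check of that assumption (not part of this module;
// matching sizes is necessary but not sufficient for layout compatibility):
//
// let sizes_match = unsafe {
//     cpp!([] -> bool as "bool" {
//         return sizeof(tflite::ops::micro::AllOpsResolver)
//             == sizeof(tflite::MicroMutableOpResolver<128>);
//     })
// };
// assert!(sizes_match);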

type OpResolverT = tflite::ops::micro::AllOpsResolver;

/// Trait for types that have the memory representation of an
/// `OpResolver`
pub trait OpResolverRepr {
    fn to_inner(self) -> OpResolverT;
}

/// An Op Resolver populated with all available operators
#[derive(Default)]
pub struct AllOpResolver(OpResolverT);
impl OpResolverRepr for AllOpResolver {
    fn to_inner(self) -> OpResolverT {
        self.0
    }
}

/// An Op Resolver that is initially empty; operators are added by calling
/// its builder-style methods
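///
/// A minimal usage sketch, mirroring this module's tests (marked `ignore`
/// rather than run as a doc test; the operator methods shown are provided
/// elsewhere in this crate):
///
/// ```ignore
/// let resolver = MutableOpResolver::empty()
///     .fully_connected()
///     .softmax();
/// ```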
#[derive(Default)]
pub struct MutableOpResolver {
    pub(crate) inner: OpResolverT,
    capacity: usize,
    len: usize,
}
impl OpResolverRepr for MutableOpResolver {
    fn to_inner(self) -> OpResolverT {
        self.inner
    }
}
impl fmt::Debug for MutableOpResolver {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_fmt(format_args!("MutableOpResolver (ops = {})", self.len))
    }
}

impl AllOpResolver {
    /// Create a new Op Resolver, populated with all available
    /// operators
    pub fn new() -> Self {
        // The AllOpsResolver C++ constructor registers every built-in
        // operator with the underlying MicroMutableOpResolver
        let micro_op_resolver = unsafe {
            cpp!([] -> OpResolverT as "tflite::ops::micro::AllOpsResolver" {
                // All ops
                tflite::ops::micro::AllOpsResolver resolver;

                return resolver;
            })
        };

        Self(micro_op_resolver)
    }
}

impl MutableOpResolver {
    /// Check that there is capacity for one more operator, then increment
    /// the operator count
    pub(crate) fn check_then_inc_len(&mut self) {
        assert!(
            self.len < self.capacity,
            "Tensorflow micro does not support more than {} operators.",
            self.capacity
        );

        self.len += 1;
    }

    /// Returns the current number of operators in this resolver
    pub fn len(&self) -> usize {
        self.len
    }

    /// Return whether there are zero operators
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Create a new MutableOpResolver, initially empty
    pub fn empty() -> Self {
        // Maximum number of registrations
        //
        // tensorflow/lite/micro/kernels/all_ops_resolver.h:L27
        let tflite_registrations_max = 128;

        let micro_op_resolver = unsafe {
            // Create the resolver object
            //
            // We still take the full memory footprint of
            // `MicroMutableOpResolver<128>` in order to stay layout
            // compatible. However, the unreferenced operator kernels
            // themselves will be optimised away
            cpp!([] -> OpResolverT as
                 "tflite::MicroMutableOpResolver<128>" {

                tflite::MicroMutableOpResolver<128> resolver;
                return resolver;
            })
        };

        Self {
            inner: micro_op_resolver,
            capacity: tflite_registrations_max,
            len: 0,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn all_ops_resolver() {
        let _ = AllOpResolver::new();
    }

    #[test]
    fn mutable_op_resolver() {
        let _ = MutableOpResolver::empty()
            .depthwise_conv_2d()
            .fully_connected()
            .softmax();
    }
}