1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
use crate::tfpb::tensor::TensorProto;
use crate::tfpb::tensor_shape::{TensorShapeProto, TensorShapeProto_Dim};
use crate::tfpb::types::DataType;
use std::convert::{TryFrom, TryInto};
use tract_core::internal::*;

impl TryFrom<DataType> for DatumType {
    type Error = TractError;
    fn try_from(t: DataType) -> TractResult<DatumType> {
        match t {
            DataType::DT_BOOL => Ok(DatumType::Bool),
            DataType::DT_UINT8 => Ok(DatumType::U8),
            DataType::DT_UINT16 => Ok(DatumType::U16),
            DataType::DT_INT8 => Ok(DatumType::I8),
            DataType::DT_INT16 => Ok(DatumType::I16),
            DataType::DT_INT32 => Ok(DatumType::I32),
            DataType::DT_INT64 => Ok(DatumType::I64),
            DataType::DT_HALF => Ok(DatumType::F16),
            DataType::DT_FLOAT => Ok(DatumType::F32),
            DataType::DT_DOUBLE => Ok(DatumType::F64),
            DataType::DT_STRING => Ok(DatumType::String),
            _ => Err(format!("Unknown DatumType {:?}", t))?,
        }
    }
}

impl<'a> TryFrom<&'a TensorShapeProto> for TVec<usize> {
    type Error = TractError;
    fn try_from(t: &'a TensorShapeProto) -> TractResult<TVec<usize>> {
        Ok(t.get_dim().iter().map(|d| d.size as usize).collect::<TVec<_>>())
    }
}

impl TryFrom<DatumType> for DataType {
    type Error = TractError;

    /// Map a tract `DatumType` onto the TensorFlow protobuf `DataType` used to
    /// serialize it.
    ///
    /// # Errors
    /// `DatumType::TDim` has no protobuf equivalent and is rejected.
    fn try_from(dt: DatumType) -> TractResult<DataType> {
        let proto = match dt {
            DatumType::TDim => bail!("Dimension is not translatable in protobuf"),
            DatumType::Bool => DataType::DT_BOOL,
            DatumType::U8 => DataType::DT_UINT8,
            DatumType::U16 => DataType::DT_UINT16,
            DatumType::I8 => DataType::DT_INT8,
            DatumType::I16 => DataType::DT_INT16,
            DatumType::I32 => DataType::DT_INT32,
            DatumType::I64 => DataType::DT_INT64,
            DatumType::F16 => DataType::DT_HALF,
            DatumType::F32 => DataType::DT_FLOAT,
            DatumType::F64 => DataType::DT_DOUBLE,
            DatumType::String => DataType::DT_STRING,
        };
        Ok(proto)
    }
}

impl<'a> TryFrom<&'a TensorProto> for Tensor {
    type Error = TractError;
    fn try_from(t: &TensorProto) -> TractResult<Tensor> {
        let dtype = t.get_dtype();
        let dims: TVec<usize> = t.get_tensor_shape().try_into()?;
        let rank = dims.len();
        let content = t.get_tensor_content();
        let mat: Tensor = if content.len() != 0 {
            unsafe {
                match dtype {
                    DataType::DT_FLOAT => Self::from_raw::<f32>(&dims, content)?,
                    DataType::DT_INT32 => Self::from_raw::<i32>(&dims, content)?,
                    DataType::DT_INT64 => Self::from_raw::<i64>(&dims, content)?,
                    _ => unimplemented!("missing type {:?}", dtype),
                }
            }
        } else {
            use ndarray::Array;
            match dtype {
                DataType::DT_INT32 => {
                    Array::from_shape_vec(&*dims, t.get_int_val().to_vec())?.into()
                }
                DataType::DT_INT64 => {
                    Array::from_shape_vec(&*dims, t.get_int64_val().to_vec())?.into()
                }
                DataType::DT_FLOAT => {
                    Array::from_shape_vec(&*dims, t.get_float_val().to_vec())?.into()
                }
                _ => unimplemented!("missing type {:?}", dtype),
            }
        };
        assert_eq!(rank, mat.shape().len());
        Ok(mat)
    }
}

impl<'a> TryFrom<&'a Tensor> for TensorProto {
    type Error = TractError;
    fn try_from(from: &Tensor) -> TractResult<TensorProto> {
        let mut shape = TensorShapeProto::new();
        let dims = from
            .shape()
            .iter()
            .map(|d| {
                let mut dim = TensorShapeProto_Dim::new();
                dim.size = *d as _;
                dim
            })
            .collect();
        shape.set_dim(::protobuf::RepeatedField::from_vec(dims));
        let mut tensor = TensorProto::new();
        tensor.set_tensor_shape(shape);
        match from.datum_type() {
            DatumType::F32 => {
                tensor.set_dtype(DatumType::F32.try_into()?);
                tensor.set_float_val(from.to_array_view::<f32>()?.iter().cloned().collect());
            }
            DatumType::F64 => {
                tensor.set_dtype(DatumType::F64.try_into()?);
                tensor.set_double_val(from.to_array_view::<f64>()?.iter().cloned().collect());
            }
            DatumType::I32 => {
                tensor.set_dtype(DatumType::I32.try_into()?);
                tensor.set_int_val(from.to_array_view::<i32>()?.iter().cloned().collect());
            }
            DatumType::I64 => {
                tensor.set_dtype(DatumType::I64.try_into()?);
                tensor.set_int64_val(from.to_array_view::<i64>()?.iter().cloned().collect());
            }
            _ => unimplemented!("missing type {:?}", from.datum_type()),
        }
        Ok(tensor)
    }
}