1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127
use crate::tfpb::tensor::TensorProto; use crate::tfpb::tensor_shape::{TensorShapeProto, TensorShapeProto_Dim}; use crate::tfpb::types::DataType; use std::convert::{TryFrom, TryInto}; use tract_core::internal::*; impl TryFrom<DataType> for DatumType { type Error = TractError; fn try_from(t: DataType) -> TractResult<DatumType> { match t { DataType::DT_BOOL => Ok(DatumType::Bool), DataType::DT_UINT8 => Ok(DatumType::U8), DataType::DT_UINT16 => Ok(DatumType::U16), DataType::DT_INT8 => Ok(DatumType::I8), DataType::DT_INT16 => Ok(DatumType::I16), DataType::DT_INT32 => Ok(DatumType::I32), DataType::DT_INT64 => Ok(DatumType::I64), DataType::DT_HALF => Ok(DatumType::F16), DataType::DT_FLOAT => Ok(DatumType::F32), DataType::DT_DOUBLE => Ok(DatumType::F64), DataType::DT_STRING => Ok(DatumType::String), _ => Err(format!("Unknown DatumType {:?}", t))?, } } } impl<'a> TryFrom<&'a TensorShapeProto> for TVec<usize> { type Error = TractError; fn try_from(t: &'a TensorShapeProto) -> TractResult<TVec<usize>> { Ok(t.get_dim().iter().map(|d| d.size as usize).collect::<TVec<_>>()) } } impl TryFrom<DatumType> for DataType { type Error = TractError; fn try_from(dt: DatumType) -> TractResult<DataType> { match dt { DatumType::Bool => Ok(DataType::DT_BOOL), DatumType::U8 => Ok(DataType::DT_UINT8), DatumType::U16 => Ok(DataType::DT_UINT16), DatumType::I8 => Ok(DataType::DT_INT8), DatumType::I16 => Ok(DataType::DT_INT16), DatumType::I32 => Ok(DataType::DT_INT32), DatumType::I64 => Ok(DataType::DT_INT64), DatumType::F16 => Ok(DataType::DT_HALF), DatumType::F32 => Ok(DataType::DT_FLOAT), DatumType::F64 => Ok(DataType::DT_DOUBLE), DatumType::String => Ok(DataType::DT_STRING), DatumType::TDim => bail!("Dimension is not translatable in protobuf"), } } } impl<'a> TryFrom<&'a TensorProto> for Tensor { type Error = TractError; fn try_from(t: &TensorProto) -> TractResult<Tensor> { let dtype = t.get_dtype(); let dims: TVec<usize> = t.get_tensor_shape().try_into()?; let rank = dims.len(); let content = t.get_tensor_content(); let mat: Tensor = if content.len() != 0 { unsafe { match dtype { DataType::DT_FLOAT => Self::from_raw::<f32>(&dims, content)?, DataType::DT_INT32 => Self::from_raw::<i32>(&dims, content)?, DataType::DT_INT64 => Self::from_raw::<i64>(&dims, content)?, _ => unimplemented!("missing type {:?}", dtype), } } } else { use ndarray::Array; match dtype { DataType::DT_INT32 => { Array::from_shape_vec(&*dims, t.get_int_val().to_vec())?.into() } DataType::DT_INT64 => { Array::from_shape_vec(&*dims, t.get_int64_val().to_vec())?.into() } DataType::DT_FLOAT => { Array::from_shape_vec(&*dims, t.get_float_val().to_vec())?.into() } _ => unimplemented!("missing type {:?}", dtype), } }; assert_eq!(rank, mat.shape().len()); Ok(mat) } } impl<'a> TryFrom<&'a Tensor> for TensorProto { type Error = TractError; fn try_from(from: &Tensor) -> TractResult<TensorProto> { let mut shape = TensorShapeProto::new(); let dims = from .shape() .iter() .map(|d| { let mut dim = TensorShapeProto_Dim::new(); dim.size = *d as _; dim }) .collect(); shape.set_dim(::protobuf::RepeatedField::from_vec(dims)); let mut tensor = TensorProto::new(); tensor.set_tensor_shape(shape); match from.datum_type() { DatumType::F32 => { tensor.set_dtype(DatumType::F32.try_into()?); tensor.set_float_val(from.to_array_view::<f32>()?.iter().cloned().collect()); } DatumType::F64 => { tensor.set_dtype(DatumType::F64.try_into()?); tensor.set_double_val(from.to_array_view::<f64>()?.iter().cloned().collect()); } DatumType::I32 => { tensor.set_dtype(DatumType::I32.try_into()?); tensor.set_int_val(from.to_array_view::<i32>()?.iter().cloned().collect()); } DatumType::I64 => { tensor.set_dtype(DatumType::I64.try_into()?); tensor.set_int64_val(from.to_array_view::<i64>()?.iter().cloned().collect()); } _ => unimplemented!("missing type {:?}", from.datum_type()), } Ok(tensor) } }