1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154
use crate::pb::tensor_proto::DataType;
use crate::pb::*;
use prost::Message;
use std::convert::{TryFrom, TryInto};
use tract_core::internal::*;
use tract_core::*;

/// Maps an ONNX `TensorProto.DataType` onto tract's `DatumType`.
///
/// Only the data types listed below have a tract equivalent; any other
/// variant is reported as an error.
impl TryFrom<DataType> for DatumType {
    type Error = TractError;
    fn try_from(t: DataType) -> TractResult<DatumType> {
        match t {
            DataType::Bool => Ok(DatumType::Bool),
            DataType::Uint8 => Ok(DatumType::U8),
            DataType::Uint16 => Ok(DatumType::U16),
            DataType::Int8 => Ok(DatumType::I8),
            DataType::Int16 => Ok(DatumType::I16),
            DataType::Int32 => Ok(DatumType::I32),
            DataType::Int64 => Ok(DatumType::I64),
            DataType::Float16 => Ok(DatumType::F16),
            DataType::Float => Ok(DatumType::F32),
            DataType::Double => Ok(DatumType::F64),
            DataType::String => Ok(DatumType::String),
            _ => Err(format!("Unknown DatumType {:?}", t))?,
        }
    }
}

/// Decodes the raw `i32` data-type code carried by ONNX protobuf messages
/// into a tract `DatumType`.
///
/// # Errors
/// Fails when `code` is not a known `DataType` variant, or when the variant
/// has no tract equivalent. Model files are untrusted input, so this must not
/// panic.
fn datum_type_from_proto(code: i32) -> TractResult<DatumType> {
    let data_type =
        DataType::from_i32(code).ok_or_else(|| format!("Unknown DataType code {}", code))?;
    DatumType::try_from(data_type)
}

/// Builds an `InferenceFact` from an ONNX `TypeProto.Tensor` description.
///
/// The element type is always set. When a shape is present, the fact gets a
/// closed shape of the same rank. Each dimension given as a strictly positive
/// `DimValue` becomes concrete. Every other dimension (non-positive value,
/// missing value, or any non-`DimValue` variant) is left unconstrained.
///
/// # Errors
/// Fails when `elem_type` is not a supported data type.
impl<'a> TryFrom<&'a type_proto::Tensor> for InferenceFact {
    type Error = TractError;
    fn try_from(t: &'a type_proto::Tensor) -> TractResult<InferenceFact> {
        let mut fact =
            InferenceFact::default().with_datum_type(datum_type_from_proto(t.elem_type)?);
        if let Some(shape) = &t.shape {
            let dims: TVec<DimFact> = shape
                .dim
                .iter()
                .map(|d| match d.value {
                    Some(tensor_shape_proto::dimension::Value::DimValue(v)) if v > 0 => {
                        DimFact::from(v.to_dim())
                    }
                    _ => DimFact::default(),
                })
                .collect();
            fact = fact.with_shape(ShapeFact::closed(dims));
        }
        Ok(fact)
    }
}

/// Owned-value convenience wrapper around the by-reference conversion.
impl TryFrom<type_proto::Tensor> for InferenceFact {
    type Error = TractError;
    fn try_from(t: type_proto::Tensor) -> TractResult<InferenceFact> {
        (&t).try_into()
    }
}

/// Materializes an ONNX `TensorProto` as a tract `Tensor`.
///
/// A non-empty `raw_data` takes precedence and is decoded as packed bytes.
/// Otherwise the typed repeated fields (`int32_data`, `int64_data`,
/// `float_data`, `double_data`, `string_data`) are used, depending on the
/// data type.
///
/// # Errors
/// Fails on:
/// - an unknown or unsupported data type;
/// - a negative dimension;
/// - element data that does not match the declared shape;
/// - invalid UTF-8 in string tensors.
impl<'a> TryFrom<&'a TensorProto> for Tensor {
    type Error = TractError;
    fn try_from(t: &TensorProto) -> TractResult<Tensor> {
        let dt = datum_type_from_proto(t.data_type)?;
        // Reject negative dims explicitly: an `as usize` cast would wrap them
        // into enormous sizes instead of reporting a malformed model.
        let shape = t
            .dims
            .iter()
            .map(|&d| usize::try_from(d).map_err(|_| format!("Invalid tensor dimension {}", d)))
            .collect::<Result<Vec<usize>, String>>()?;
        if !t.raw_data.is_empty() {
            tensor_from_raw_data(dt, &shape, &t.raw_data)
        } else {
            tensor_from_typed_fields(dt, &shape, t)
        }
    }
}

/// Decodes a tensor whose elements are packed into the `raw_data` byte buffer.
///
/// Booleans are read as one byte each, and any non-zero byte is `true`.
///
/// # Errors
/// Fails for data types that have no raw decoding here (e.g. strings), and
/// propagates any error from `Tensor::from_raw`.
fn tensor_from_raw_data(dt: DatumType, shape: &[usize], raw: &[u8]) -> TractResult<Tensor> {
    // SAFETY: NOTE(review): `Tensor::from_raw` is an unsafe tract API that
    // reinterprets `raw` as elements of the requested type. It is only
    // instantiated with plain numeric types below. Confirm its exact contract
    // (alignment / length handling) against tract_core.
    unsafe {
        match dt {
            DatumType::U8 => Tensor::from_raw::<u8>(shape, raw),
            DatumType::U16 => Tensor::from_raw::<u16>(shape, raw),
            DatumType::I8 => Tensor::from_raw::<i8>(shape, raw),
            DatumType::I16 => Tensor::from_raw::<i16>(shape, raw),
            DatumType::I32 => Tensor::from_raw::<i32>(shape, raw),
            DatumType::I64 => Tensor::from_raw::<i64>(shape, raw),
            DatumType::F16 => Tensor::from_raw::<f16>(shape, raw),
            DatumType::F32 => Tensor::from_raw::<f32>(shape, raw),
            DatumType::F64 => Tensor::from_raw::<f64>(shape, raw),
            DatumType::Bool => Ok(Tensor::from_raw::<u8>(shape, raw)?
                .into_array::<u8>()?
                .mapv(|x| x != 0)
                .into()),
            _ => Err(format!("Unsupported raw tensor loading for {:?}", dt).into()),
        }
    }
}

/// Decodes a tensor from the typed repeated fields of a `TensorProto`.
///
/// Small integer types (bool, u8, u16, i8, i16) are stored widened in
/// `int32_data` and are narrowed here with `as` casts, matching the original
/// behavior. Booleans map any non-zero value to `true`.
///
/// # Errors
/// Fails for data types with no typed-field decoding here (e.g. F16), when
/// the element count does not match `shape`, or on invalid UTF-8 strings.
fn tensor_from_typed_fields(
    dt: DatumType,
    shape: &[usize],
    t: &TensorProto,
) -> TractResult<Tensor> {
    use ndarray::Array;
    let tensor: Tensor = match dt {
        DatumType::Bool => {
            Array::from_shape_vec(shape, t.int32_data.iter().map(|&x| x != 0).collect())?.into()
        }
        DatumType::U8 => {
            Array::from_shape_vec(shape, t.int32_data.iter().map(|&x| x as u8).collect())?.into()
        }
        DatumType::U16 => {
            Array::from_shape_vec(shape, t.int32_data.iter().map(|&x| x as u16).collect())?.into()
        }
        DatumType::I8 => {
            Array::from_shape_vec(shape, t.int32_data.iter().map(|&x| x as i8).collect())?.into()
        }
        DatumType::I16 => {
            Array::from_shape_vec(shape, t.int32_data.iter().map(|&x| x as i16).collect())?.into()
        }
        DatumType::I32 => Array::from_shape_vec(shape, t.int32_data.to_vec())?.into(),
        DatumType::I64 => Array::from_shape_vec(shape, t.int64_data.to_vec())?.into(),
        DatumType::F32 => Array::from_shape_vec(shape, t.float_data.to_vec())?.into(),
        DatumType::F64 => Array::from_shape_vec(shape, t.double_data.to_vec())?.into(),
        DatumType::String => {
            let strings = t
                .string_data
                .iter()
                .cloned()
                .map(String::from_utf8)
                .collect::<Result<Vec<String>, _>>()
                .map_err(|_| format!("Invalid UTF8 buffer"))?;
            Array::from_shape_vec(shape, strings)?.into()
        }
        _ => return Err(format!("Unsupported tensor loading for {:?}", dt).into()),
    };
    Ok(tensor)
}

/// Owned-value convenience wrapper around the by-reference conversion.
impl TryFrom<TensorProto> for Tensor {
    type Error = TractError;
    fn try_from(t: TensorProto) -> TractResult<Tensor> {
        (&t).try_into()
    }
}

/// Reads `r` to the end and decodes the bytes as a serialized `TensorProto`.
///
/// # Errors
/// Fails if reading from `r` fails or the bytes are not a valid protobuf
/// `TensorProto`.
pub fn proto_from_reader<R: ::std::io::Read>(mut r: R) -> TractResult<TensorProto> {
    let mut v = vec![];
    r.read_to_end(&mut v)?;
    TensorProto::decode(v).map_err(|e| format!("Can not parse protobuf input: {:?}", e).into())
}

/// Reads a serialized `TensorProto` from `r` and converts it into a tract
/// `Tensor`.
///
/// # Errors
/// Propagates errors from [`proto_from_reader`] and from the
/// `TensorProto` to `Tensor` conversion.
pub fn from_reader<R: ::std::io::Read>(r: R) -> TractResult<Tensor> {
    proto_from_reader(r)?.try_into()
}