tract_tensorflow/
tensor.rs1use tract_hir::internal::*;
2
3use crate::tfpb::tensorflow::tensor_shape_proto::Dim;
4use crate::tfpb::tensorflow::{TensorProto, TensorShapeProto};
5
6use crate::tfpb::tensorflow::DataType;
7use std::convert::TryFrom;
8
9impl TryFrom<DataType> for DatumType {
10 type Error = TractError;
11 fn try_from(t: DataType) -> TractResult<DatumType> {
12 match t {
13 DataType::DtBool => Ok(DatumType::Bool),
14 DataType::DtUint8 => Ok(DatumType::U8),
15 DataType::DtUint16 => Ok(DatumType::U16),
16 DataType::DtUint32 => Ok(DatumType::U32),
17 DataType::DtUint64 => Ok(DatumType::U64),
18 DataType::DtInt8 => Ok(DatumType::I8),
19 DataType::DtInt16 => Ok(DatumType::I16),
20 DataType::DtInt32 => Ok(DatumType::I32),
21 DataType::DtInt64 => Ok(DatumType::I64),
22 DataType::DtHalf => Ok(DatumType::F16),
23 DataType::DtFloat => Ok(DatumType::F32),
24 DataType::DtDouble => Ok(DatumType::F64),
25 DataType::DtString => Ok(DatumType::Blob),
26 _ => Err(format_err!("Unknown DatumType {:?}", t))?,
27 }
28 }
29}
30
31impl<'a> TryFrom<&'a TensorShapeProto> for TVec<isize> {
32 type Error = TractError;
33 fn try_from(t: &'a TensorShapeProto) -> TractResult<TVec<isize>> {
34 Ok(t.dim.iter().map(|d| d.size as isize).collect::<TVec<_>>())
35 }
36}
37
38impl<'a> TryFrom<&'a TensorShapeProto> for TVec<usize> {
39 type Error = TractError;
40 fn try_from(t: &'a TensorShapeProto) -> TractResult<TVec<usize>> {
41 if t.dim.iter().any(|d| d.size < 0) {
42 bail!("Negative dim found")
43 }
44 Ok(t.dim.iter().map(|d| d.size as usize).collect::<TVec<_>>())
45 }
46}
47
48impl TryFrom<DatumType> for DataType {
49 type Error = TractError;
50 fn try_from(dt: DatumType) -> TractResult<DataType> {
51 match dt {
52 DatumType::Bool => Ok(DataType::DtBool),
53 DatumType::U8 => Ok(DataType::DtUint8),
54 DatumType::U16 => Ok(DataType::DtUint16),
55 DatumType::U32 => Ok(DataType::DtUint32),
56 DatumType::U64 => Ok(DataType::DtUint64),
57 DatumType::I8 => Ok(DataType::DtInt8),
58 DatumType::I16 => Ok(DataType::DtInt16),
59 DatumType::I32 => Ok(DataType::DtInt32),
60 DatumType::I64 => Ok(DataType::DtInt64),
61 DatumType::F16 => Ok(DataType::DtHalf),
62 DatumType::F32 => Ok(DataType::DtFloat),
63 DatumType::F64 => Ok(DataType::DtDouble),
64 DatumType::Blob => Ok(DataType::DtString),
65 DatumType::String => Ok(DataType::DtString),
66 DatumType::QI8(_) => Ok(DataType::DtQint8),
67 DatumType::QU8(_) => Ok(DataType::DtQint8),
68 DatumType::QI32(_) => Ok(DataType::DtQint32),
69 _ => bail!("DatumType is not translatable in protobuf"),
70 }
71 }
72}
73
74fn tensor_from_repeated_field<T: Datum>(shape: &[usize], data: Vec<T>) -> TractResult<Tensor> {
75 let t = if data.len() == 1 {
76 tract_ndarray::ArrayD::from_elem(shape, data[0].clone()).into()
77 } else {
78 tract_ndarray::ArrayD::from_shape_vec(shape, data.to_vec())?.into()
79 };
80 Ok(t)
81}
82
83impl TryFrom<&TensorProto> for Tensor {
84 type Error = TractError;
85 fn try_from(t: &TensorProto) -> TractResult<Tensor> {
86 let dims: TVec<usize> =
87 t.tensor_shape.as_ref().unwrap().dim.iter().map(|x| x.size as _).collect();
88 let rank = dims.len();
89 let content = &t.tensor_content;
90 let dtype = DataType::from_i32(t.dtype).unwrap();
91 let mat: Tensor = if content.len() != 0 {
92 unsafe {
93 match dtype {
94 DataType::DtFloat => Self::from_raw::<f32>(&dims, content)?,
95 DataType::DtDouble => Self::from_raw::<f64>(&dims, content)?,
96 DataType::DtInt32 => Self::from_raw::<i32>(&dims, content)?,
97 DataType::DtInt64 => Self::from_raw::<i64>(&dims, content)?,
98 _ => unimplemented!("missing type (for get_tensor_content) {:?}", dtype),
99 }
100 }
101 } else {
102 match dtype {
103 DataType::DtInt32 => tensor_from_repeated_field(&dims, t.int_val.to_vec())?,
104 DataType::DtInt64 => tensor_from_repeated_field(&dims, t.int64_val.to_vec())?,
105 DataType::DtFloat => tensor_from_repeated_field(&dims, t.float_val.to_vec())?,
106 DataType::DtDouble => tensor_from_repeated_field(&dims, t.double_val.to_vec())?,
107 DataType::DtString => {
108 let strings = t
109 .string_val
110 .iter()
111 .map(|s| Blob::try_from(&**s))
112 .collect::<TractResult<Vec<Blob>>>()?;
113 tensor_from_repeated_field(&dims, strings)?
114 }
115 _ => unimplemented!("missing type (for _val()) {:?}", t.dtype),
116 }
117 };
118 assert_eq!(rank, mat.shape().len());
119 Ok(mat)
120 }
121}
122
123fn empty_tensor_proto() -> TensorProto {
124 TensorProto {
125 dtype: 0,
126 tensor_shape: None,
127 version_number: 0,
128 tensor_content: vec![],
129 half_val: vec![],
130 float_val: vec![],
131 double_val: vec![],
132 int_val: vec![],
133 string_val: vec![],
134 scomplex_val: vec![],
135 dcomplex_val: vec![],
136 resource_handle_val: vec![],
137 variant_val: vec![],
138 uint32_val: vec![],
139 uint64_val: vec![],
140 int64_val: vec![],
141 bool_val: vec![],
142 }
143}
144
145impl TryFrom<&Tensor> for TensorProto {
146 type Error = TractError;
147 fn try_from(from: &Tensor) -> TractResult<TensorProto> {
148 let mut tensor = empty_tensor_proto();
149 let shape = TensorShapeProto {
150 dim: from.shape().iter().map(|d| Dim { size: *d as _, name: String::new() }).collect(),
151 unknown_rank: false,
152 };
153 tensor.tensor_shape = Some(shape);
154 let dt = DataType::try_from(from.datum_type())?;
155 tensor.dtype = dt.into();
156 match from.datum_type() {
157 DatumType::F32 => {
158 tensor.float_val = from.to_array_view::<f32>()?.iter().cloned().collect();
159 }
160 DatumType::F64 => {
161 tensor.double_val = from.to_array_view::<f64>()?.iter().cloned().collect();
162 }
163 DatumType::I32 => {
164 tensor.int_val = from.to_array_view::<i32>()?.iter().cloned().collect();
165 }
166 DatumType::I64 => {
167 tensor.int64_val = from.to_array_view::<i64>()?.iter().cloned().collect();
168 }
169 _ => unimplemented!("missing type {:?}", from.datum_type()),
170 }
171 Ok(tensor)
172 }
173}