tract_tensorflow/
tensor.rs

1use tract_hir::internal::*;
2
3use crate::tfpb::tensorflow::tensor_shape_proto::Dim;
4use crate::tfpb::tensorflow::{TensorProto, TensorShapeProto};
5
6use crate::tfpb::tensorflow::DataType;
7use std::convert::TryFrom;
8
9impl TryFrom<DataType> for DatumType {
10    type Error = TractError;
11    fn try_from(t: DataType) -> TractResult<DatumType> {
12        match t {
13            DataType::DtBool => Ok(DatumType::Bool),
14            DataType::DtUint8 => Ok(DatumType::U8),
15            DataType::DtUint16 => Ok(DatumType::U16),
16            DataType::DtUint32 => Ok(DatumType::U32),
17            DataType::DtUint64 => Ok(DatumType::U64),
18            DataType::DtInt8 => Ok(DatumType::I8),
19            DataType::DtInt16 => Ok(DatumType::I16),
20            DataType::DtInt32 => Ok(DatumType::I32),
21            DataType::DtInt64 => Ok(DatumType::I64),
22            DataType::DtHalf => Ok(DatumType::F16),
23            DataType::DtFloat => Ok(DatumType::F32),
24            DataType::DtDouble => Ok(DatumType::F64),
25            DataType::DtString => Ok(DatumType::Blob),
26            _ => Err(format_err!("Unknown DatumType {:?}", t))?,
27        }
28    }
29}
30
31impl<'a> TryFrom<&'a TensorShapeProto> for TVec<isize> {
32    type Error = TractError;
33    fn try_from(t: &'a TensorShapeProto) -> TractResult<TVec<isize>> {
34        Ok(t.dim.iter().map(|d| d.size as isize).collect::<TVec<_>>())
35    }
36}
37
38impl<'a> TryFrom<&'a TensorShapeProto> for TVec<usize> {
39    type Error = TractError;
40    fn try_from(t: &'a TensorShapeProto) -> TractResult<TVec<usize>> {
41        if t.dim.iter().any(|d| d.size < 0) {
42            bail!("Negative dim found")
43        }
44        Ok(t.dim.iter().map(|d| d.size as usize).collect::<TVec<_>>())
45    }
46}
47
48impl TryFrom<DatumType> for DataType {
49    type Error = TractError;
50    fn try_from(dt: DatumType) -> TractResult<DataType> {
51        match dt {
52            DatumType::Bool => Ok(DataType::DtBool),
53            DatumType::U8 => Ok(DataType::DtUint8),
54            DatumType::U16 => Ok(DataType::DtUint16),
55            DatumType::U32 => Ok(DataType::DtUint32),
56            DatumType::U64 => Ok(DataType::DtUint64),
57            DatumType::I8 => Ok(DataType::DtInt8),
58            DatumType::I16 => Ok(DataType::DtInt16),
59            DatumType::I32 => Ok(DataType::DtInt32),
60            DatumType::I64 => Ok(DataType::DtInt64),
61            DatumType::F16 => Ok(DataType::DtHalf),
62            DatumType::F32 => Ok(DataType::DtFloat),
63            DatumType::F64 => Ok(DataType::DtDouble),
64            DatumType::Blob => Ok(DataType::DtString),
65            DatumType::String => Ok(DataType::DtString),
66            DatumType::QI8(_) => Ok(DataType::DtQint8),
67            DatumType::QU8(_) => Ok(DataType::DtQint8),
68            DatumType::QI32(_) => Ok(DataType::DtQint32),
69            _ => bail!("DatumType is not translatable in protobuf"),
70        }
71    }
72}
73
74fn tensor_from_repeated_field<T: Datum>(shape: &[usize], data: Vec<T>) -> TractResult<Tensor> {
75    let t = if data.len() == 1 {
76        tract_ndarray::ArrayD::from_elem(shape, data[0].clone()).into()
77    } else {
78        tract_ndarray::ArrayD::from_shape_vec(shape, data.to_vec())?.into()
79    };
80    Ok(t)
81}
82
83impl TryFrom<&TensorProto> for Tensor {
84    type Error = TractError;
85    fn try_from(t: &TensorProto) -> TractResult<Tensor> {
86        let dims: TVec<usize> =
87            t.tensor_shape.as_ref().unwrap().dim.iter().map(|x| x.size as _).collect();
88        let rank = dims.len();
89        let content = &t.tensor_content;
90        let dtype = DataType::from_i32(t.dtype).unwrap();
91        let mat: Tensor = if content.len() != 0 {
92            unsafe {
93                match dtype {
94                    DataType::DtFloat => Self::from_raw::<f32>(&dims, content)?,
95                    DataType::DtDouble => Self::from_raw::<f64>(&dims, content)?,
96                    DataType::DtInt32 => Self::from_raw::<i32>(&dims, content)?,
97                    DataType::DtInt64 => Self::from_raw::<i64>(&dims, content)?,
98                    _ => unimplemented!("missing type (for get_tensor_content) {:?}", dtype),
99                }
100            }
101        } else {
102            match dtype {
103                DataType::DtInt32 => tensor_from_repeated_field(&dims, t.int_val.to_vec())?,
104                DataType::DtInt64 => tensor_from_repeated_field(&dims, t.int64_val.to_vec())?,
105                DataType::DtFloat => tensor_from_repeated_field(&dims, t.float_val.to_vec())?,
106                DataType::DtDouble => tensor_from_repeated_field(&dims, t.double_val.to_vec())?,
107                DataType::DtString => {
108                    let strings = t
109                        .string_val
110                        .iter()
111                        .map(|s| Blob::try_from(&**s))
112                        .collect::<TractResult<Vec<Blob>>>()?;
113                    tensor_from_repeated_field(&dims, strings)?
114                }
115                _ => unimplemented!("missing type (for _val()) {:?}", t.dtype),
116            }
117        };
118        assert_eq!(rank, mat.shape().len());
119        Ok(mat)
120    }
121}
122
123fn empty_tensor_proto() -> TensorProto {
124    TensorProto {
125        dtype: 0,
126        tensor_shape: None,
127        version_number: 0,
128        tensor_content: vec![],
129        half_val: vec![],
130        float_val: vec![],
131        double_val: vec![],
132        int_val: vec![],
133        string_val: vec![],
134        scomplex_val: vec![],
135        dcomplex_val: vec![],
136        resource_handle_val: vec![],
137        variant_val: vec![],
138        uint32_val: vec![],
139        uint64_val: vec![],
140        int64_val: vec![],
141        bool_val: vec![],
142    }
143}
144
145impl TryFrom<&Tensor> for TensorProto {
146    type Error = TractError;
147    fn try_from(from: &Tensor) -> TractResult<TensorProto> {
148        let mut tensor = empty_tensor_proto();
149        let shape = TensorShapeProto {
150            dim: from.shape().iter().map(|d| Dim { size: *d as _, name: String::new() }).collect(),
151            unknown_rank: false,
152        };
153        tensor.tensor_shape = Some(shape);
154        let dt = DataType::try_from(from.datum_type())?;
155        tensor.dtype = dt.into();
156        match from.datum_type() {
157            DatumType::F32 => {
158                tensor.float_val = from.to_array_view::<f32>()?.iter().cloned().collect();
159            }
160            DatumType::F64 => {
161                tensor.double_val = from.to_array_view::<f64>()?.iter().cloned().collect();
162            }
163            DatumType::I32 => {
164                tensor.int_val = from.to_array_view::<i32>()?.iter().cloned().collect();
165            }
166            DatumType::I64 => {
167                tensor.int64_val = from.to_array_view::<i64>()?.iter().cloned().collect();
168            }
169            _ => unimplemented!("missing type {:?}", from.datum_type()),
170        }
171        Ok(tensor)
172    }
173}