1use crate::data_resolver::ModelDataResolver;
2use crate::model::ParsingContext;
3use crate::pb::tensor_proto::DataType;
4use crate::pb::*;
5use prost::Message;
6use std::convert::{TryFrom, TryInto};
7use std::path::PathBuf;
8use tract_hir::internal::*;
9
10impl TryFrom<DataType> for DatumType {
11 type Error = TractError;
12 fn try_from(t: DataType) -> TractResult<DatumType> {
13 match t {
14 DataType::Bool => Ok(DatumType::Bool),
15 DataType::Uint8 => Ok(DatumType::U8),
16 DataType::Uint16 => Ok(DatumType::U16),
17 DataType::Uint32 => Ok(DatumType::U32),
18 DataType::Uint64 => Ok(DatumType::U64),
19 DataType::Int8 => Ok(DatumType::I8),
20 DataType::Int16 => Ok(DatumType::I16),
21 DataType::Int32 => Ok(DatumType::I32),
22 DataType::Int64 => Ok(DatumType::I64),
23 DataType::Float16 => Ok(DatumType::F16),
24 DataType::Float => Ok(DatumType::F32),
25 DataType::Double => Ok(DatumType::F64),
26 DataType::String => Ok(DatumType::String),
27 _ => bail!("Unknown DatumType {:?}", t),
28 }
29 }
30}
31
32pub fn translate_inference_fact(
33 ctx: &ParsingContext,
34 t: &type_proto::Tensor,
35 include_unknown_symbols: bool,
36) -> TractResult<InferenceFact> {
37 let mut fact = InferenceFact::default();
38 fact = fact.with_datum_type(DataType::from_i32(t.elem_type).unwrap().try_into()?);
39 if let Some(shape) = &t.shape {
40 let shape: TVec<DimFact> = shape
41 .dim
42 .iter()
43 .map(|d| -> TractResult<DimFact> {
44 match &d.value {
45 Some(tensor_shape_proto::dimension::Value::DimValue(v)) if *v >= 0 => {
46 Ok(DimFact::from(v.to_dim()))
47 }
48 Some(tensor_shape_proto::dimension::Value::DimParam(v)) => {
49 if v == "?" || (v.starts_with("unk__") && !include_unknown_symbols) {
50 Ok(DimFact::default())
51 } else {
52 let dim = parse_tdim(&ctx.template.symbols, v)
53 .with_context(|| format!("Parsing as TDim: `{v}'"))?;
54 Ok(DimFact::from(dim))
55 }
56 }
57 _ => Ok(DimFact::default()),
58 }
59 })
60 .collect::<TractResult<_>>()?;
61 fact = fact.with_shape(ShapeFactoid::closed(shape));
62 }
63 Ok(fact)
64}
65
66fn get_external_resources(
67 provider: &dyn ModelDataResolver,
68 t: &TensorProto,
69 path: &str,
70) -> TractResult<Vec<u8>> {
71 let mut tensor_data: Vec<u8> = Vec::new();
72 trace!("number of external file needed for this tensor: {}", t.external_data.len());
73 let location = t
74 .external_data
75 .iter()
76 .find(|it| it.key == "location")
77 .map(|it| it.value.as_str())
78 .context("Could not find external data location")?;
79
80 let offset: usize = t
81 .external_data
82 .iter()
83 .find(|it| it.key == "offset")
84 .map(|it| it.value.parse())
85 .transpose()
86 .context("Error while parsing offset value on external data description")?
87 .unwrap_or(0);
88
89 let length: Option<usize> = t
90 .external_data
91 .iter()
92 .find(|it| it.key == "length")
93 .map(|it| it.value.parse())
94 .transpose()
95 .context("Error while parsing length value on external data description")?;
96
97 let p = PathBuf::from(path).join(location);
98
99 trace!("external file detected: {p:?}, offset {offset:?}, length: {length:?}");
100 provider.read_bytes_from_path(&mut tensor_data, &p, offset, length)?;
101 trace!("external file loaded");
102 Ok(tensor_data)
103}
104
105fn create_tensor(shape: Vec<usize>, dt: DatumType, data: &[u8]) -> TractResult<Tensor> {
106 unsafe {
107 match dt {
108 DatumType::U8 => Tensor::from_raw::<u8>(&shape, data),
109 DatumType::U16 => Tensor::from_raw::<u16>(&shape, data),
110 DatumType::U32 => Tensor::from_raw::<u32>(&shape, data),
111 DatumType::U64 => Tensor::from_raw::<u64>(&shape, data),
112 DatumType::I8 => Tensor::from_raw::<i8>(&shape, data),
113 DatumType::I16 => Tensor::from_raw::<i16>(&shape, data),
114 DatumType::I32 => Tensor::from_raw::<i32>(&shape, data),
115 DatumType::I64 => Tensor::from_raw::<i64>(&shape, data),
116 DatumType::F16 => Tensor::from_raw::<f16>(&shape, data),
117 DatumType::F32 => Tensor::from_raw::<f32>(&shape, data),
118 DatumType::F64 => Tensor::from_raw::<f64>(&shape, data),
119 DatumType::Bool => Ok(Tensor::from_raw::<u8>(&shape, data)?
120 .into_array::<u8>()?
121 .mapv(|x| x != 0)
122 .into()),
123 _ => unimplemented!("FIXME, raw tensor loading"),
124 }
125 }
126}
127
128pub fn load_tensor(
129 provider: &dyn ModelDataResolver,
130 t: &TensorProto,
131 path: Option<&str>,
132) -> TractResult<Tensor> {
133 let dt = DataType::from_i32(t.data_type).unwrap().try_into()?;
134 let shape: Vec<usize> = t.dims.iter().map(|&i| i as usize).collect();
135 let is_external = t.data_location.is_some()
137 && t.data_location == Some(tensor_proto::DataLocation::External as i32);
138 if t.raw_data.len() > 0 {
139 create_tensor(shape, dt, &t.raw_data)
140 } else if is_external {
141 if let Some(model_path) = path {
142 let external_data = get_external_resources(provider, t, model_path)?;
144 create_tensor(shape, dt, &external_data)
145 } else {
146 bail!("no model path was specified in the parsing context, yet external data was detected. aborting");
147 }
148 } else {
149 use tract_ndarray::Array;
150 let it = match dt {
151 DatumType::Bool => {
152 Array::from_shape_vec(&*shape, t.int32_data.iter().map(|&x| x != 0).collect())?
153 .into()
154 }
155 DatumType::U8 => {
156 Array::from_shape_vec(&*shape, t.int32_data.iter().map(|&x| x as u8).collect())?
157 .into()
158 }
159 DatumType::U16 => {
160 Array::from_shape_vec(&*shape, t.int32_data.iter().map(|&x| x as u16).collect())?
161 .into()
162 }
163 DatumType::U32 => Array::from_shape_vec(&*shape, t.int32_data.to_vec())?.into(),
164 DatumType::U64 => Array::from_shape_vec(&*shape, t.int64_data.to_vec())?.into(),
165 DatumType::I8 => {
166 Array::from_shape_vec(&*shape, t.int32_data.iter().map(|&x| x as i8).collect())?
167 .into()
168 }
169 DatumType::I16 => {
170 Array::from_shape_vec(&*shape, t.int32_data.iter().map(|&x| x as i16).collect())?
171 .into()
172 }
173 DatumType::I32 => Array::from_shape_vec(&*shape, t.int32_data.to_vec())?.into(),
174 DatumType::I64 => Array::from_shape_vec(&*shape, t.int64_data.to_vec())?.into(),
175 DatumType::F16 => Array::from_shape_vec(
176 &*shape,
177 t.int32_data.iter().map(|&x| f16::from_bits(x as u16)).collect(),
178 )?
179 .into(),
180 DatumType::F32 => Array::from_shape_vec(&*shape, t.float_data.to_vec())?.into(),
181 DatumType::F64 => Array::from_shape_vec(&*shape, t.double_data.to_vec())?.into(),
182 DatumType::String => {
183 let strings = t
184 .string_data
185 .iter()
186 .cloned()
187 .map(String::from_utf8)
188 .collect::<Result<Vec<String>, _>>()
189 .context("Invalid UTF8 buffer")?;
190 Array::from_shape_vec(&*shape, strings)?.into()
191 }
192 _ => unimplemented!("FIXME, struct tensor loading: {:?}", dt),
193 };
194 Ok(it)
195 }
196}
197
198pub fn proto_from_reader<R: ::std::io::Read>(mut r: R) -> TractResult<TensorProto> {
199 let mut v = vec![];
200 r.read_to_end(&mut v)?;
201 let b = bytes::Bytes::from(v);
202 TensorProto::decode(b).context("Can not parse protobuf input")
203}