tract_onnx/
tensor.rs

1use crate::data_resolver::ModelDataResolver;
2use crate::model::ParsingContext;
3use crate::pb::tensor_proto::DataType;
4use crate::pb::*;
5use prost::Message;
6use std::convert::{TryFrom, TryInto};
7use std::path::PathBuf;
8use tract_hir::internal::*;
9
10impl TryFrom<DataType> for DatumType {
11    type Error = TractError;
12    fn try_from(t: DataType) -> TractResult<DatumType> {
13        match t {
14            DataType::Bool => Ok(DatumType::Bool),
15            DataType::Uint8 => Ok(DatumType::U8),
16            DataType::Uint16 => Ok(DatumType::U16),
17            DataType::Uint32 => Ok(DatumType::U32),
18            DataType::Uint64 => Ok(DatumType::U64),
19            DataType::Int8 => Ok(DatumType::I8),
20            DataType::Int16 => Ok(DatumType::I16),
21            DataType::Int32 => Ok(DatumType::I32),
22            DataType::Int64 => Ok(DatumType::I64),
23            DataType::Float16 => Ok(DatumType::F16),
24            DataType::Float => Ok(DatumType::F32),
25            DataType::Double => Ok(DatumType::F64),
26            DataType::String => Ok(DatumType::String),
27            _ => bail!("Unknown DatumType {:?}", t),
28        }
29    }
30}
31
32pub fn translate_inference_fact(
33    ctx: &ParsingContext,
34    t: &type_proto::Tensor,
35    include_unknown_symbols: bool,
36) -> TractResult<InferenceFact> {
37    let mut fact = InferenceFact::default();
38    fact = fact.with_datum_type(DataType::from_i32(t.elem_type).unwrap().try_into()?);
39    if let Some(shape) = &t.shape {
40        let shape: TVec<DimFact> = shape
41            .dim
42            .iter()
43            .map(|d| -> TractResult<DimFact> {
44                match &d.value {
45                    Some(tensor_shape_proto::dimension::Value::DimValue(v)) if *v >= 0 => {
46                        Ok(DimFact::from(v.to_dim()))
47                    }
48                    Some(tensor_shape_proto::dimension::Value::DimParam(v)) => {
49                        if v == "?" || (v.starts_with("unk__") && !include_unknown_symbols) {
50                            Ok(DimFact::default())
51                        } else {
52                            let dim = parse_tdim(&ctx.template.symbols, v)
53                                .with_context(|| format!("Parsing as TDim: `{v}'"))?;
54                            Ok(DimFact::from(dim))
55                        }
56                    }
57                    _ => Ok(DimFact::default()),
58                }
59            })
60            .collect::<TractResult<_>>()?;
61        fact = fact.with_shape(ShapeFactoid::closed(shape));
62    }
63    Ok(fact)
64}
65
66fn get_external_resources(
67    provider: &dyn ModelDataResolver,
68    t: &TensorProto,
69    path: &str,
70) -> TractResult<Vec<u8>> {
71    let mut tensor_data: Vec<u8> = Vec::new();
72    trace!("number of external file needed for this tensor: {}", t.external_data.len());
73    let location = t
74        .external_data
75        .iter()
76        .find(|it| it.key == "location")
77        .map(|it| it.value.as_str())
78        .context("Could not find external data location")?;
79
80    let offset: usize = t
81        .external_data
82        .iter()
83        .find(|it| it.key == "offset")
84        .map(|it| it.value.parse())
85        .transpose()
86        .context("Error while parsing offset value on external data description")?
87        .unwrap_or(0);
88
89    let length: Option<usize> = t
90        .external_data
91        .iter()
92        .find(|it| it.key == "length")
93        .map(|it| it.value.parse())
94        .transpose()
95        .context("Error while parsing length value on external data description")?;
96
97    let p = PathBuf::from(path).join(location);
98
99    trace!("external file detected: {p:?}, offset {offset:?}, length: {length:?}");
100    provider.read_bytes_from_path(&mut tensor_data, &p, offset, length)?;
101    trace!("external file loaded");
102    Ok(tensor_data)
103}
104
105fn create_tensor(shape: Vec<usize>, dt: DatumType, data: &[u8]) -> TractResult<Tensor> {
106    unsafe {
107        match dt {
108            DatumType::U8 => Tensor::from_raw::<u8>(&shape, data),
109            DatumType::U16 => Tensor::from_raw::<u16>(&shape, data),
110            DatumType::U32 => Tensor::from_raw::<u32>(&shape, data),
111            DatumType::U64 => Tensor::from_raw::<u64>(&shape, data),
112            DatumType::I8 => Tensor::from_raw::<i8>(&shape, data),
113            DatumType::I16 => Tensor::from_raw::<i16>(&shape, data),
114            DatumType::I32 => Tensor::from_raw::<i32>(&shape, data),
115            DatumType::I64 => Tensor::from_raw::<i64>(&shape, data),
116            DatumType::F16 => Tensor::from_raw::<f16>(&shape, data),
117            DatumType::F32 => Tensor::from_raw::<f32>(&shape, data),
118            DatumType::F64 => Tensor::from_raw::<f64>(&shape, data),
119            DatumType::Bool => Ok(Tensor::from_raw::<u8>(&shape, data)?
120                .into_array::<u8>()?
121                .mapv(|x| x != 0)
122                .into()),
123            _ => unimplemented!("FIXME, raw tensor loading"),
124        }
125    }
126}
127
128pub fn load_tensor(
129    provider: &dyn ModelDataResolver,
130    t: &TensorProto,
131    path: Option<&str>,
132) -> TractResult<Tensor> {
133    let dt = DataType::from_i32(t.data_type).unwrap().try_into()?;
134    let shape: Vec<usize> = t.dims.iter().map(|&i| i as usize).collect();
135    // detect if the tensor is rather in an external file than inside the onnx file directly
136    let is_external = t.data_location.is_some()
137        && t.data_location == Some(tensor_proto::DataLocation::External as i32);
138    if t.raw_data.len() > 0 {
139        create_tensor(shape, dt, &t.raw_data)
140    } else if is_external {
141        if let Some(model_path) = path {
142            // external files will be loaded and fed to the tensor if necessary
143            let external_data = get_external_resources(provider, t, model_path)?;
144            create_tensor(shape, dt, &external_data)
145        } else {
146            bail!("no model path was specified in the parsing context, yet external data was detected. aborting");
147        }
148    } else {
149        use tract_ndarray::Array;
150        let it = match dt {
151            DatumType::Bool => {
152                Array::from_shape_vec(&*shape, t.int32_data.iter().map(|&x| x != 0).collect())?
153                    .into()
154            }
155            DatumType::U8 => {
156                Array::from_shape_vec(&*shape, t.int32_data.iter().map(|&x| x as u8).collect())?
157                    .into()
158            }
159            DatumType::U16 => {
160                Array::from_shape_vec(&*shape, t.int32_data.iter().map(|&x| x as u16).collect())?
161                    .into()
162            }
163            DatumType::U32 => Array::from_shape_vec(&*shape, t.int32_data.to_vec())?.into(),
164            DatumType::U64 => Array::from_shape_vec(&*shape, t.int64_data.to_vec())?.into(),
165            DatumType::I8 => {
166                Array::from_shape_vec(&*shape, t.int32_data.iter().map(|&x| x as i8).collect())?
167                    .into()
168            }
169            DatumType::I16 => {
170                Array::from_shape_vec(&*shape, t.int32_data.iter().map(|&x| x as i16).collect())?
171                    .into()
172            }
173            DatumType::I32 => Array::from_shape_vec(&*shape, t.int32_data.to_vec())?.into(),
174            DatumType::I64 => Array::from_shape_vec(&*shape, t.int64_data.to_vec())?.into(),
175            DatumType::F16 => Array::from_shape_vec(
176                &*shape,
177                t.int32_data.iter().map(|&x| f16::from_bits(x as u16)).collect(),
178            )?
179            .into(),
180            DatumType::F32 => Array::from_shape_vec(&*shape, t.float_data.to_vec())?.into(),
181            DatumType::F64 => Array::from_shape_vec(&*shape, t.double_data.to_vec())?.into(),
182            DatumType::String => {
183                let strings = t
184                    .string_data
185                    .iter()
186                    .cloned()
187                    .map(String::from_utf8)
188                    .collect::<Result<Vec<String>, _>>()
189                    .context("Invalid UTF8 buffer")?;
190                Array::from_shape_vec(&*shape, strings)?.into()
191            }
192            _ => unimplemented!("FIXME, struct tensor loading: {:?}", dt),
193        };
194        Ok(it)
195    }
196}
197
198pub fn proto_from_reader<R: ::std::io::Read>(mut r: R) -> TractResult<TensorProto> {
199    let mut v = vec![];
200    r.read_to_end(&mut v)?;
201    let b = bytes::Bytes::from(v);
202    TensorProto::decode(b).context("Can not parse protobuf input")
203}