wonnx_preprocessing/
constant_folding.rs

1use std::{borrow::Cow, collections::HashMap};
2
3use protobuf::{ProtobufEnum, RepeatedField};
4use thiserror::Error;
5
6use wonnx::{
7    constant_of_shape_output,
8    onnx::{
9        GraphProto, NodeProto, TensorProto, TensorShapeProto, TensorShapeProto_Dimension,
10        TypeProto, TypeProto_Tensor, ValueInfoProto,
11    },
12    utils::{
13        model_with_opset, DataTypeError, InputTensor, NodeAttributes, OutputTensor, ScalarType,
14        Shape,
15    },
16    CompileError, GpuError, Session, SessionError,
17};
18
19#[derive(Error, Debug)]
20pub enum ConstantFoldingError {
21    #[error("unsupported data type encountered: {0}")]
22    #[from(DataTypeError)]
23    UnsupportedDataType(DataTypeError),
24
25    #[error("invalid node: {0}")]
26    InvalidNode(String),
27
28    #[error("error calculating constant value: {0}")]
29    #[from(SessionError)]
30    CalculationError(SessionError),
31}
32
33pub(crate) async fn calculate_constant_node_outputs<'a>(
34    node: &'a NodeProto,
35    shapes: &'a HashMap<String, Shape>,
36    inputs: &'a [InputTensor<'a>],
37    output_shapes: &[Shape],
38    _initializers: &HashMap<String, Cow<'a, TensorProto>>,
39    opset_version: i64,
40) -> Result<Option<Vec<OutputTensor>>, ConstantFoldingError> {
41    Ok(match node.get_op_type() {
42        "Identity" | "Unsqueeze" | "Squeeze" | "Reshape" => {
43            Some(inputs.iter().map(OutputTensor::from).collect())
44        }
45        "Cast" => {
46            let cast_to_type =
47                ScalarType::from_i32(node.get_attribute_value::<i64>("to", None).map_err(|_| {
48                    ConstantFoldingError::InvalidNode("to attribute missing for Cast ".to_string())
49                })? as i32)
50                .map_err(ConstantFoldingError::UnsupportedDataType)?;
51            let input_tensor = &inputs[0];
52
53            let output_tensor = match (input_tensor, cast_to_type) {
54                (InputTensor::F32(v), ScalarType::F32) => OutputTensor::F32(v.to_vec()),
55                (InputTensor::F32(v), ScalarType::I64) => {
56                    OutputTensor::I64(v.iter().map(|x| *x as i64).collect())
57                }
58                (InputTensor::F32(v), ScalarType::I32) => {
59                    OutputTensor::I32(v.iter().map(|x| *x as i32).collect())
60                }
61                (InputTensor::F32(v), ScalarType::U8) => {
62                    OutputTensor::U8(v.iter().map(|x| *x as u8).collect())
63                }
64                (InputTensor::I32(v), ScalarType::F32) => {
65                    OutputTensor::F32(v.iter().map(|x| *x as f32).collect())
66                }
67                (InputTensor::I32(v), ScalarType::I64) => {
68                    OutputTensor::I64(v.iter().map(|x| *x as i64).collect())
69                }
70                (InputTensor::I32(v), ScalarType::I32) => OutputTensor::I32(v.to_vec()),
71                (InputTensor::I32(v), ScalarType::U8) => {
72                    OutputTensor::U8(v.iter().map(|x| *x as u8).collect())
73                }
74                (InputTensor::I64(v), ScalarType::F32) => {
75                    OutputTensor::F32(v.iter().map(|x| *x as f32).collect())
76                }
77                (InputTensor::I64(v), ScalarType::I64) => OutputTensor::I64(v.to_vec()),
78                (InputTensor::I64(v), ScalarType::I32) => {
79                    OutputTensor::I32(v.iter().map(|x| *x as i32).collect())
80                }
81                (InputTensor::I64(v), ScalarType::U8) => {
82                    OutputTensor::U8(v.iter().map(|x| *x as u8).collect())
83                }
84                (InputTensor::U8(v), ScalarType::F32) => {
85                    OutputTensor::F32(v.iter().map(|x| *x as f32).collect())
86                }
87                (InputTensor::U8(v), ScalarType::I64) => {
88                    OutputTensor::I64(v.iter().map(|x| *x as i64).collect())
89                }
90                (InputTensor::U8(v), ScalarType::I32) => {
91                    OutputTensor::I32(v.iter().map(|x| *x as i32).collect())
92                }
93                (InputTensor::U8(v), ScalarType::U8) => OutputTensor::U8(v.to_vec()),
94            };
95
96            Some(vec![output_tensor])
97        }
98
99        // Shape: produces an output containing the shape of the input tensor
100        "Shape" => {
101            let input_shape = &shapes[&node.input[0]];
102            Some(vec![calculate_shape_operator(node, input_shape)?])
103        }
104
105        // ConstantOfShape: produces an output of the shape specified by the input, filled with a constant value specified in an attribute
106        "ConstantOfShape" => {
107            if let InputTensor::I64(input_shape) = &inputs[0] {
108                let element_count = input_shape.iter().product::<i64>() as usize;
109                Some(vec![constant_of_shape_output(node, element_count)
110                    .map_err(|e| {
111                        ConstantFoldingError::InvalidNode(e.to_string())
112                    })?])
113            } else {
114                return Err(ConstantFoldingError::InvalidNode(
115                    "ConstantOfShape node input tensor has invalid type, should be i64".to_string(),
116                ));
117            }
118        }
119
120        _ => {
121            // Try to run on GPU
122            let mut graph = GraphProto::new();
123            graph.set_input(RepeatedField::from(
124                node.input
125                    .iter()
126                    .enumerate()
127                    .map(|(index, input)| {
128                        let shape = &shapes[input];
129                        input_to_value_info(shape, &format!("input_{}", index))
130                    })
131                    .collect::<Vec<_>>(),
132            ));
133
134            graph.set_output(RepeatedField::from(
135                node.output
136                    .iter()
137                    .enumerate()
138                    .map(|(index, _output)| {
139                        let shape = &output_shapes[index];
140                        input_to_value_info(shape, &format!("output_{}", index))
141                    })
142                    .collect::<Vec<_>>(),
143            ));
144
145            let mut temp_node = node.clone();
146            temp_node.set_output(RepeatedField::from(
147                graph
148                    .output
149                    .iter()
150                    .map(|otp| otp.get_name().to_string())
151                    .collect::<Vec<String>>(),
152            ));
153            temp_node.set_input(RepeatedField::from(
154                graph
155                    .input
156                    .iter()
157                    .map(|otp| otp.get_name().to_string())
158                    .collect::<Vec<String>>(),
159            ));
160            graph.set_node(RepeatedField::from(vec![temp_node]));
161
162            let model = model_with_opset(graph, opset_version);
163
164            let session = match Session::from_model(model).await {
165                Ok(v) => v,
166                Err(e) => {
167                    if let SessionError::GpuError(GpuError::CompileError {
168                        error: CompileError::UnimplementedOp(op_name),
169                        ..
170                    }) = e
171                    {
172                        log::info!("could not constant-fold node '{}', because op '{}' is not yet implemented", node.get_name(), op_name);
173                        return Ok(None);
174                    } else {
175                        return Err(ConstantFoldingError::CalculationError(e));
176                    }
177                }
178            };
179
180            let mut named_inputs: HashMap<String, InputTensor> = HashMap::new();
181            for (index, input) in inputs.iter().enumerate() {
182                let input: InputTensor = input.to_owned();
183                named_inputs.insert(format!("input_{}", index), input);
184            }
185
186            let mut output_values = session
187                .run(&named_inputs)
188                .await
189                .map_err(ConstantFoldingError::CalculationError)?;
190
191            let outputs: Vec<OutputTensor> = (0..node.output.len())
192                .map(|output_index| {
193                    let output_key = format!("output_{}", output_index);
194                    output_values.remove(&output_key).unwrap()
195                })
196                .collect();
197
198            Some(outputs)
199        }
200    })
201}
202
203fn input_to_value_info(shape: &Shape, name: &str) -> ValueInfoProto {
204    let mut ttp = TypeProto_Tensor::new();
205    ttp.set_elem_type(shape.data_type.to_datatype().value());
206    let mut tsp = TensorShapeProto::new();
207    tsp.set_dim(RepeatedField::from(
208        shape
209            .dims
210            .iter()
211            .map(|x| {
212                let mut tdp = TensorShapeProto_Dimension::new();
213                tdp.set_dim_value(*x as i64);
214                tdp
215            })
216            .collect::<Vec<TensorShapeProto_Dimension>>(),
217    ));
218    ttp.set_shape(tsp);
219    let mut ftp = TypeProto::new();
220    ftp.set_tensor_type(ttp);
221    let mut vip = ValueInfoProto::new();
222    vip.set_name(name.to_string());
223    vip.set_field_type(ftp);
224    vip
225}
226
227fn calculate_shape_operator(
228    node: &NodeProto,
229    input_shape: &Shape,
230) -> Result<OutputTensor, ConstantFoldingError> {
231    let input_dims: Vec<i64> = input_shape.dims.iter().map(|x| *x as i64).collect();
232    let mut start = node.get_attribute_value("start", Some(0)).unwrap();
233    let mut end = node
234        .get_attribute_value("end", Some(input_dims.len() as i64))
235        .unwrap();
236    if start < 0 {
237        start += input_dims.len() as i64;
238    }
239    if end < 0 {
240        end += input_dims.len() as i64;
241    }
242    start = start.clamp(0, input_dims.len() as i64);
243    end = end.clamp(0, input_dims.len() as i64);
244
245    if start > end {
246        return Err(ConstantFoldingError::InvalidNode(format!(
247            "end attribute value ({}) for Shape node should be higher than start attribute ({})",
248            end, start
249        )));
250    }
251
252    let output_shape: Vec<i64> = (input_dims[(start as usize)..=((end - 1) as usize)]).into();
253    if output_shape.is_empty() {
254        log::warn!("Shape operator results in an empty output shape which is probably an issue... start={start} end={end} input_shape={}", input_shape);
255    }
256
257    Ok(OutputTensor::I64(output_shape))
258}
259
#[cfg(test)]
mod test {
    use wonnx::utils::{attribute, node, OutputTensor, Shape};

    use super::calculate_shape_operator;

    /// Builds a `Shape` node with the given optional `start`/`end` attributes. Applies it to an
    /// F32 tensor with dimensions `dims` and asserts that the result is the I64 tensor `out_dims`.
    pub fn test_shape_shape_inference_slice(
        dims: &[i64],
        start: Option<i64>,
        end: Option<i64>,
        out_dims: &[i64],
    ) {
        let attrs = [("start", start), ("end", end)]
            .iter()
            .filter_map(|&(attr_name, value)| value.map(|v| attribute(attr_name, v)))
            .collect::<Vec<_>>();
        let shape_node = node(vec!["X"], vec!["Y"], "s", "Shape", attrs);
        let input_shape = Shape::from(wonnx::utils::ScalarType::F32, dims);
        let expected = OutputTensor::I64(out_dims.to_vec());
        assert_eq!(
            calculate_shape_operator(&shape_node, &input_shape).unwrap(),
            expected
        );
    }

    #[test]
    pub fn test_shape_shape_inference() {
        // (input dims, start, end, expected output)
        let cases: &[(&[i64], Option<i64>, Option<i64>, &[i64])] = &[
            (&[3, 4, 5], None, None, &[3, 4, 5]),
            (&[3, 4, 5], Some(1), None, &[4, 5]),
            (&[3, 4, 5], Some(10), None, &[]),
            (&[3, 4, 5], Some(10), Some(11), &[]),
            (&[3, 4, 5], Some(-1), None, &[5]),
            (&[3, 4, 5], Some(-3), Some(-2), &[3]),
        ];
        for &(dims, start, end, out_dims) in cases {
            test_shape_shape_inference_slice(dims, start, end, out_dims);
        }
    }
}