use std::{borrow::Cow, collections::HashMap};

use protobuf::{ProtobufEnum, RepeatedField};
use thiserror::Error;

use wonnx::{
    constant_of_shape_output,
    onnx::{
        GraphProto, NodeProto, TensorProto, TensorShapeProto, TensorShapeProto_Dimension,
        TypeProto, TypeProto_Tensor, ValueInfoProto,
    },
    utils::{
        model_with_opset, DataTypeError, InputTensor, NodeAttributes, OutputTensor, ScalarType,
        Shape,
    },
    CompileError, GpuError, Session, SessionError,
};

#[derive(Error, Debug)]
pub enum ConstantFoldingError {
    #[error("unsupported data type encountered: {0}")]
    UnsupportedDataType(#[from] DataTypeError),

    #[error("invalid node: {0}")]
    InvalidNode(String),

    #[error("error calculating constant value: {0}")]
    CalculationError(#[from] SessionError),
}

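/// Tries to compute the outputs of `node` ahead of time, given that all of its inputs are
/// already known constants. Returns `Ok(None)` when the node cannot be folded (for example
/// because its op is not implemented yet). A few shape-related ops are evaluated directly on
/// the CPU; any other op is folded by running it as a temporary single-node model.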
pub(crate) async fn calculate_constant_node_outputs<'a>(
    node: &'a NodeProto,
    shapes: &'a HashMap<String, Shape>,
    inputs: &'a [InputTensor<'a>],
    output_shapes: &[Shape],
    _initializers: &HashMap<String, Cow<'a, TensorProto>>,
    opset_version: i64,
) -> Result<Option<Vec<OutputTensor>>, ConstantFoldingError> {
    Ok(match node.get_op_type() {
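        // Data-preserving ops: the element values are unchanged and only the shape metadata
        // differs, so the constant inputs can be passed through as outputs.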
42 "Identity" | "Unsqueeze" | "Squeeze" | "Reshape" => {
43 Some(inputs.iter().map(OutputTensor::from).collect())
44 }
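
        // Cast is folded on the CPU by converting the input values element-wise to the scalar
        // type requested by the node's "to" attribute.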
45 "Cast" => {
46 let cast_to_type =
47 ScalarType::from_i32(node.get_attribute_value::<i64>("to", None).map_err(|_| {
48 ConstantFoldingError::InvalidNode("to attribute missing for Cast ".to_string())
49 })? as i32)
50 .map_err(ConstantFoldingError::UnsupportedDataType)?;
            let input_tensor = &inputs[0];

            let output_tensor = match (input_tensor, cast_to_type) {
                (InputTensor::F32(v), ScalarType::F32) => OutputTensor::F32(v.to_vec()),
                (InputTensor::F32(v), ScalarType::I64) => {
                    OutputTensor::I64(v.iter().map(|x| *x as i64).collect())
                }
                (InputTensor::F32(v), ScalarType::I32) => {
                    OutputTensor::I32(v.iter().map(|x| *x as i32).collect())
                }
                (InputTensor::F32(v), ScalarType::U8) => {
                    OutputTensor::U8(v.iter().map(|x| *x as u8).collect())
                }
                (InputTensor::I32(v), ScalarType::F32) => {
                    OutputTensor::F32(v.iter().map(|x| *x as f32).collect())
                }
                (InputTensor::I32(v), ScalarType::I64) => {
                    OutputTensor::I64(v.iter().map(|x| *x as i64).collect())
                }
                (InputTensor::I32(v), ScalarType::I32) => OutputTensor::I32(v.to_vec()),
                (InputTensor::I32(v), ScalarType::U8) => {
                    OutputTensor::U8(v.iter().map(|x| *x as u8).collect())
                }
                (InputTensor::I64(v), ScalarType::F32) => {
                    OutputTensor::F32(v.iter().map(|x| *x as f32).collect())
                }
                (InputTensor::I64(v), ScalarType::I64) => OutputTensor::I64(v.to_vec()),
                (InputTensor::I64(v), ScalarType::I32) => {
                    OutputTensor::I32(v.iter().map(|x| *x as i32).collect())
                }
                (InputTensor::I64(v), ScalarType::U8) => {
                    OutputTensor::U8(v.iter().map(|x| *x as u8).collect())
                }
                (InputTensor::U8(v), ScalarType::F32) => {
                    OutputTensor::F32(v.iter().map(|x| *x as f32).collect())
                }
                (InputTensor::U8(v), ScalarType::I64) => {
                    OutputTensor::I64(v.iter().map(|x| *x as i64).collect())
                }
                (InputTensor::U8(v), ScalarType::I32) => {
                    OutputTensor::I32(v.iter().map(|x| *x as i32).collect())
                }
                (InputTensor::U8(v), ScalarType::U8) => OutputTensor::U8(v.to_vec()),
            };

            Some(vec![output_tensor])
        }

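        // Shape returns (a slice of) the input's dimensions; the result depends only on the
        // inferred input shape, so it can always be computed here.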
99 "Shape" => {
101 let input_shape = &shapes[&node.input[0]];
102 Some(vec![calculate_shape_operator(node, input_shape)?])
103 }
104
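        // ConstantOfShape produces a tensor of the requested shape filled with a constant
        // value; only the total element count is needed to generate the output.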
105 "ConstantOfShape" => {
107 if let InputTensor::I64(input_shape) = &inputs[0] {
108 let element_count = input_shape.iter().product::<i64>() as usize;
109 Some(vec![constant_of_shape_output(node, element_count)
110 .map_err(|e| {
111 ConstantFoldingError::InvalidNode(e.to_string())
112 })?])
113 } else {
114 return Err(ConstantFoldingError::InvalidNode(
115 "ConstantOfShape node input tensor has invalid type, should be i64".to_string(),
116 ));
117 }
118 }
119
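        // Any other op: wrap the node in a temporary single-node model using the known input
        // and output shapes, run it once, and use the results as the folded outputs.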
        _ => {
            let mut graph = GraphProto::new();
            graph.set_input(RepeatedField::from(
                node.input
                    .iter()
                    .enumerate()
                    .map(|(index, input)| {
                        let shape = &shapes[input];
                        input_to_value_info(shape, &format!("input_{}", index))
                    })
                    .collect::<Vec<_>>(),
            ));

            graph.set_output(RepeatedField::from(
                node.output
                    .iter()
                    .enumerate()
                    .map(|(index, _output)| {
                        let shape = &output_shapes[index];
                        input_to_value_info(shape, &format!("output_{}", index))
                    })
                    .collect::<Vec<_>>(),
            ));

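            // Rewire a copy of the node so it reads from and writes to the generic
            // "input_N"/"output_N" names declared above.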
            let mut temp_node = node.clone();
            temp_node.set_output(RepeatedField::from(
                graph
                    .output
                    .iter()
                    .map(|otp| otp.get_name().to_string())
                    .collect::<Vec<String>>(),
            ));
            temp_node.set_input(RepeatedField::from(
                graph
                    .input
                    .iter()
                    .map(|otp| otp.get_name().to_string())
                    .collect::<Vec<String>>(),
            ));
            graph.set_node(RepeatedField::from(vec![temp_node]));

            let model = model_with_opset(graph, opset_version);

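            // Compile the single-node model; if the op is not implemented yet, skip folding
            // this node instead of failing the whole pass.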
            let session = match Session::from_model(model).await {
                Ok(v) => v,
                Err(e) => {
                    if let SessionError::GpuError(GpuError::CompileError {
                        error: CompileError::UnimplementedOp(op_name),
                        ..
                    }) = e
                    {
                        log::info!(
                            "could not constant-fold node '{}', because op '{}' is not yet implemented",
                            node.get_name(),
                            op_name
                        );
                        return Ok(None);
                    } else {
                        return Err(ConstantFoldingError::CalculationError(e));
                    }
                }
            };

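            // Feed the known constant inputs under the generic names expected by the
            // temporary graph.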
            let mut named_inputs: HashMap<String, InputTensor> = HashMap::new();
            for (index, input) in inputs.iter().enumerate() {
                let input: InputTensor = input.to_owned();
                named_inputs.insert(format!("input_{}", index), input);
            }

            let mut output_values = session
                .run(&named_inputs)
                .await
                .map_err(ConstantFoldingError::CalculationError)?;

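            // Collect the outputs in declaration order; each one was produced under its
            // generic "output_N" name.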
            let outputs: Vec<OutputTensor> = (0..node.output.len())
                .map(|output_index| {
                    let output_key = format!("output_{}", output_index);
                    output_values.remove(&output_key).unwrap()
                })
                .collect();

            Some(outputs)
        }
    })
}

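/// Builds a `ValueInfoProto` describing a tensor with the given name, element type and fixed
/// dimensions; used for the inputs and outputs of the temporary single-node graph above.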
fn input_to_value_info(shape: &Shape, name: &str) -> ValueInfoProto {
    let mut ttp = TypeProto_Tensor::new();
    ttp.set_elem_type(shape.data_type.to_datatype().value());
    let mut tsp = TensorShapeProto::new();
    tsp.set_dim(RepeatedField::from(
        shape
            .dims
            .iter()
            .map(|x| {
                let mut tdp = TensorShapeProto_Dimension::new();
                tdp.set_dim_value(*x as i64);
                tdp
            })
            .collect::<Vec<TensorShapeProto_Dimension>>(),
    ));
    ttp.set_shape(tsp);
    let mut ftp = TypeProto::new();
    ftp.set_tensor_type(ttp);
    let mut vip = ValueInfoProto::new();
    vip.set_name(name.to_string());
    vip.set_field_type(ftp);
    vip
}

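/// Evaluates the ONNX Shape operator: returns the input's dimensions as an i64 tensor,
/// optionally restricted to the `[start, end)` range given by the node's attributes (negative
/// values count from the back of the dimension list; out-of-range values are clamped).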
fn calculate_shape_operator(
    node: &NodeProto,
    input_shape: &Shape,
) -> Result<OutputTensor, ConstantFoldingError> {
    let input_dims: Vec<i64> = input_shape.dims.iter().map(|x| *x as i64).collect();
    let mut start = node.get_attribute_value("start", Some(0)).unwrap();
    let mut end = node
        .get_attribute_value("end", Some(input_dims.len() as i64))
        .unwrap();
    if start < 0 {
        start += input_dims.len() as i64;
    }
    if end < 0 {
        end += input_dims.len() as i64;
    }
    start = start.clamp(0, input_dims.len() as i64);
    end = end.clamp(0, input_dims.len() as i64);

    if start > end {
        return Err(ConstantFoldingError::InvalidNode(format!(
            "end attribute value ({}) for Shape node should not be less than the start attribute value ({})",
            end, start
        )));
    }

    // Use an exclusive range so that end == 0 yields an empty slice instead of underflowing.
    let output_shape: Vec<i64> = input_dims[(start as usize)..(end as usize)].into();
    if output_shape.is_empty() {
        log::warn!(
            "Shape operator results in an empty output shape which is probably an issue... start={start} end={end} input_shape={}",
            input_shape
        );
    }

    Ok(OutputTensor::I64(output_shape))
}

#[cfg(test)]
mod test {
    use wonnx::utils::{attribute, node, OutputTensor, Shape};

    use super::calculate_shape_operator;

    pub fn test_shape_shape_inference_slice(
        dims: &[i64],
        start: Option<i64>,
        end: Option<i64>,
        out_dims: &[i64],
    ) {
        let mut attrs = vec![];
        if let Some(start) = start {
            attrs.push(attribute("start", start));
        }
        if let Some(end) = end {
            attrs.push(attribute("end", end));
        }
        let node = node(vec!["X"], vec!["Y"], "s", "Shape", attrs);
        let shape = Shape::from(wonnx::utils::ScalarType::F32, dims);
        assert_eq!(
            calculate_shape_operator(&node, &shape).unwrap(),
            OutputTensor::I64(out_dims.to_vec())
        );
    }

    #[test]
    pub fn test_shape_shape_inference() {
        test_shape_shape_inference_slice(&[3, 4, 5], None, None, &[3, 4, 5]);
        test_shape_shape_inference_slice(&[3, 4, 5], Some(1), None, &[4, 5]);
        test_shape_shape_inference_slice(&[3, 4, 5], Some(10), None, &[]);
        test_shape_shape_inference_slice(&[3, 4, 5], Some(10), Some(11), &[]);

        test_shape_shape_inference_slice(&[3, 4, 5], Some(-1), None, &[5]);
        test_shape_shape_inference_slice(&[3, 4, 5], Some(-3), Some(-2), &[3]);
    }
}
297}