use std::{borrow::Cow, collections::HashMap};

use protobuf::ProtobufEnum;
use thiserror::Error;
use wonnx::{
    onnx::{
        GraphProto, NodeProto, TensorProto, TensorShapeProto, TensorShapeProto_Dimension,
        TypeProto, TypeProto_Tensor, TypeProto_oneof_value, ValueInfoProto,
    },
    utils::{
        AttributeNotFoundError, DataTypeError, InputTensor, NodeAttributes, ScalarType, Shape,
    },
};

use crate::constant_folding::{calculate_constant_node_outputs, ConstantFoldingError};

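/// Replaces dynamic dimension parameters (e.g. a "batch_size" dim_param) in the graph's value_info,
/// inputs and outputs with the concrete values provided in `dynamic_dims`.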
pub fn apply_dynamic_dimensions(graph: &mut GraphProto, dynamic_dims: &HashMap<String, i64>) {
    for value_info in graph.mut_value_info() {
        apply_dynamic_dimensions_value(value_info, dynamic_dims);
    }

    for value_info in graph.mut_input() {
        apply_dynamic_dimensions_value(value_info, dynamic_dims);
    }

    for value_info in graph.mut_output() {
        apply_dynamic_dimensions_value(value_info, dynamic_dims);
    }
}

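/// Integer division that rounds up (ceiling division); intended for the non-negative operands used here.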
fn div_ceil(num: i64, div: i64) -> i64 {
    num / div + (num % div != 0) as i64
}

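/// Reads the contents of the named static initializer as an `i64` slice, falling back to the raw data bytes
/// when the `int64_data` field does not hold the expected number of values. Fails if the initializer is
/// missing or does not have the int64 data type.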
fn static_initializer_value_i64<'a>(
    initializers: &'a HashMap<String, Cow<'a, TensorProto>>,
    name: &str,
) -> Result<&'a [i64], ShapeInferenceError> {
    if let Some(shape_tensor) = initializers.get(name) {
        if shape_tensor.get_data_type() != ScalarType::I64.to_datatype().value() {
            return Err(ShapeInferenceError::Unsupported(format!(
                "initializer {} has data type {} and not int64, which is currently not supported",
                name,
                shape_tensor.get_data_type()
            )));
        }

        let expected_value_count: i64 = shape_tensor.get_dims().iter().product();

        if shape_tensor.get_int64_data().len() != expected_value_count as usize {
            let raw_data = shape_tensor.get_raw_data();
            if raw_data.len() / 8 == expected_value_count as usize {
                log::warn!(
                    "int64 data for initializer {name} contains {} values, expected {expected_value_count}. Raw data length ({}) matches however, using that. dims={:?}",
                    shape_tensor.get_int64_data().len(),
                    shape_tensor.get_raw_data().len(),
                    shape_tensor.get_dims()
                );
                return Ok(bytemuck::cast_slice(raw_data));
            } else {
                log::warn!(
                    "int64 data for initializer {name} contains {} values, expected {expected_value_count}. Raw data length ({}) doesn't match either! dims={:?}",
                    shape_tensor.get_int64_data().len(),
                    shape_tensor.get_raw_data().len(),
                    shape_tensor.get_dims()
                );
            }
        }

        Ok(shape_tensor.get_int64_data())
    } else {
        Err(ShapeInferenceError::Unsupported(format!(
            "input {} is dynamic (only static initializers are supported)",
            name
        )))
    }
}

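/// Replaces any dimension parameters in a single `ValueInfoProto` with the concrete values from `dynamic_dims`.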
fn apply_dynamic_dimensions_value(
    value_info: &mut ValueInfoProto,
    dynamic_dims: &HashMap<String, i64>,
) {
    let name = value_info.get_name().to_string();
    let field_type = value_info.mut_field_type();

    if let Some(TypeProto_oneof_value::tensor_type(field_type_value)) = &mut field_type.value {
        let dims = field_type_value.mut_shape().mut_dim();

        for (idx, dim) in dims.iter_mut().enumerate() {
            if let Some(new_dim_value) = dynamic_dims.get(dim.get_dim_param()) {
                println!(
                    "Setting dimension param {idx} ({}) to value {new_dim_value} for {name}",
                    dim.get_dim_param()
                );
                dim.clear_dim_param();
                dim.set_dim_value(*new_dim_value);
            }
        }
    }
}

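/// Collects the known shape for each tensor name in the graph, gathering shape information from the graph's
/// inputs, outputs, value_info entries and initializers.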
pub(crate) fn dimensions_infos(
    graph_proto: &GraphProto,
) -> Result<HashMap<String, Shape>, DataTypeError> {
    let mut shapes_info = HashMap::new();

    for info in graph_proto.get_input() {
        if let Ok(shape) = info.get_shape() {
            shapes_info.insert(info.get_name().to_string(), shape);
        }
    }

    for info in graph_proto.get_output() {
        if let Ok(shape) = info.get_shape() {
            if shapes_info
                .insert(info.get_name().to_string(), shape)
                .is_some()
            {
                log::warn!(
                    "already had shape information for '{}', replacing from outputs",
                    info.get_name()
                );
            }
        }
    }

    for info in graph_proto.get_value_info() {
        if let Ok(shape) = info.get_shape() {
            if shapes_info
                .insert(info.get_name().to_string(), shape)
                .is_some()
            {
                log::warn!(
                    "already had shape information for '{}', replacing from value_info",
                    info.get_name()
                );
            }
        }
    }

    for info in graph_proto.get_initializer() {
        if let Ok(data_type) = ScalarType::from_i32(info.get_data_type()) {
            let shape = Shape::from(data_type, info.get_dims());
            if shapes_info
                .insert(info.get_name().to_string(), shape)
                .is_some()
            {
                log::warn!(
                    "already had shape information for '{}', replacing from initializer",
                    info.get_name()
                );
            }
        }
    }

    Ok(shapes_info)
}

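/// Errors that can occur during shape inference (including the constant folding it performs).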
#[derive(Error, Debug)]
pub enum ShapeInferenceError {
    #[error("missing shape for input {0}")]
    MissingInputShape(String),

    #[error("incomplete or missing shape for input {0} - be sure to specify all dynamic dimension parameters")]
    IncompleteInputShape(String),

    #[error("unsupported: {0}")]
    Unsupported(String),

    #[error("node {0} is invalid: {1}")]
    InvalidNode(String, String),

    #[error("attribute {0} required for shape inference is missing")]
    #[from(AttributeNotFoundError)]
    MissingAttribute(AttributeNotFoundError),

    #[error("unsupported data type encountered: {0}")]
    #[from(DataTypeError)]
    UnsupportedDataType(DataTypeError),

    #[error("constant folding failed: {0}")]
    #[from(ConstantFoldingError)]
    ConstantFoldingError(ConstantFoldingError),
}

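/// Replaces all Constant nodes in the graph with initializers, so their values can participate in shape
/// inference and constant folding.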
fn replace_constant_ops_with_initializers(
    graph: &mut GraphProto,
) -> Result<(), ShapeInferenceError> {
    for node_index in (0..graph.node.len()).rev() {
        let is_constant = graph.node[node_index].get_op_type() == "Constant";

        if is_constant {
            {
                let node = &graph.node[node_index];
                if node.get_output().len() != 1 {
                    return Err(ShapeInferenceError::InvalidNode(
                        node.get_name().to_string(),
                        format!(
                            "Constant op must have one output, has {}",
                            node.get_output().len()
                        ),
                    ));
                }

                let mut initializer = TensorProto::new();

                if let Ok(values) = node.get_attribute_value::<Vec<f32>>("value_floats", None) {
                    initializer.set_data_type(ScalarType::F32.to_datatype().value());
                    initializer.set_dims(vec![values.len() as i64]);
                    initializer.set_float_data(values);
                } else if let Ok(values) = node.get_attribute_value::<Vec<i64>>("value_ints", None)
                {
                    initializer.set_data_type(ScalarType::I64.to_datatype().value());
                    initializer.set_dims(vec![values.len() as i64]);
                    initializer.set_int64_data(values);
                } else if let Ok(values) = node.get_attribute_value::<i64>("value_int", None) {
                    initializer.set_int64_data(vec![values]);
                    initializer.set_data_type(ScalarType::I64.to_datatype().value());
                    initializer.set_dims(vec![1]);
                } else if let Ok(values) = node.get_attribute_value::<f32>("value_float", None) {
                    initializer.set_float_data(vec![values]);
                    initializer.set_data_type(ScalarType::F32.to_datatype().value());
                    initializer.set_dims(vec![1]);
                } else if let Ok(tp) = node.get_attribute_value::<TensorProto>("value", None) {
                    initializer = tp;
                    fix_raw_tensor(&mut initializer)?;
                } else {
                    log::debug!("Constant node attributes: {:?}", node.attribute);
                    return Err(ShapeInferenceError::Unsupported(
                        "Constant node with data types other than float, int".to_string(),
                    ));
                }

                log::info!(
                    "Replacing Constant node '{}' with an initializer (name='{}', shape={:?})",
                    node.get_name(),
                    node.output[0].clone(),
                    initializer.dims
                );

                initializer.set_name(node.output[0].clone());
                graph.initializer.push(initializer);
            }
            graph.node.remove(node_index);
        }
    }
    Ok(())
}

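/// Infers output shapes for all nodes in the graph that do not yet have shape information, adding the results
/// to the graph's value_info. When `should_fold_constants` is set, nodes whose inputs are all known constants
/// are folded into initializers and removed from the graph.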
pub async fn infer_shapes(
    graph: &mut GraphProto,
    should_fold_constants: bool,
    opset_version: i64,
) -> Result<(), ShapeInferenceError> {
    let mut foldable_nodes: Vec<String> = vec![];
    let mut folded_node_indexes: Vec<usize> = vec![];

    if should_fold_constants {
        replace_constant_ops_with_initializers(graph)?;
    }

    let mut shapes = dimensions_infos(graph).map_err(ShapeInferenceError::UnsupportedDataType)?;

    let mut initializers: HashMap<String, Cow<TensorProto>> = HashMap::from_iter(
        graph
            .initializer
            .iter()
            .map(|x| (x.get_name().to_string(), Cow::Borrowed(x))),
    );

    for (node_index, node) in graph.node.iter().enumerate() {
        log::debug!(
            "node: {} {} inputs {} -> outputs {}",
            node.get_op_type(),
            node.get_name(),
            node.get_input().join(", "),
            node.get_output().join(", ")
        );

        if node
            .get_output()
            .iter()
            .any(|output_name| !shapes.contains_key(output_name.as_str()))
        {
            log::debug!("node needs shape inference: {}", node.get_name());

            let input_shapes: Vec<&Shape> = node
                .get_input()
                .iter()
                .map(|name| {
                    shapes
                        .get(name)
                        .ok_or_else(|| ShapeInferenceError::MissingInputShape(name.clone()))
                })
                .collect::<Result<_, ShapeInferenceError>>()?;

            let output_shapes = infer_output_shapes(node, &input_shapes, &initializers)?;

            for (output_index, shape) in output_shapes.iter().enumerate() {
                if shape.rank() == 0 {
                    log::warn!(
                        "inferred shape for output {output_index} of node '{}' is empty: {shape}",
                        node.get_name()
                    );
                }
            }

            log::info!(
                "node {} inferred output shapes: {}",
                node.get_name(),
                output_shapes
                    .iter()
                    .enumerate()
                    .map(|(idx, x)| format!("{}={x}", node.output[idx]))
                    .collect::<Vec<String>>()
                    .join(", ")
            );

            if output_shapes.len() != node.get_output().len() {
                panic!("number of outputs inferred does not match node output count");
            }

            for (output_idx, output_name) in node.get_output().iter().enumerate() {
                let output_shape = &output_shapes[output_idx];
                shapes.insert(output_name.clone(), output_shape.clone());
                let mut vip = ValueInfoProto::new();
                vip.set_name(output_name.clone());

                let mut tip = TypeProto::new();
                let mut ttp = TypeProto_Tensor::new();
                ttp.set_elem_type(output_shape.data_type.to_datatype().value());

                let mut tsp = TensorShapeProto::new();
                tsp.set_dim(
                    output_shape
                        .dims
                        .iter()
                        .map(|d| {
                            let mut tspd = TensorShapeProto_Dimension::new();
                            tspd.set_dim_value(*d as i64);
                            tspd
                        })
                        .collect(),
                );
                ttp.set_shape(tsp);
                tip.set_tensor_type(ttp);
                vip.set_field_type(tip);
                graph.value_info.push(vip);
            }

            let can_fold = should_fold_constants && {
                let all_inputs_are_constant = node
                    .input
                    .iter()
                    .all(|input_name| initializers.contains_key(input_name));
                let is_known_shape_node =
                    node.get_op_type() == "Shape" && shapes.contains_key(&node.input[0]);
                all_inputs_are_constant || is_known_shape_node
            };

            if can_fold {
                log::debug!("node '{}' can be folded", node.get_name());

                let inputs: Vec<InputTensor> = node
                    .input
                    .iter()
                    .map(|input_name| {
                        if let Some(initializer) = initializers.get(input_name) {
                            InputTensor::try_from(initializer.as_ref())
                        } else {
                            Ok(InputTensor::I64(Cow::Owned(vec![])))
                        }
                    })
                    .collect::<Result<_, _>>()
                    .map_err(|x| {
                        ShapeInferenceError::ConstantFoldingError(
                            ConstantFoldingError::UnsupportedDataType(x),
                        )
                    })?;

                if let Some(mut constant_output) = calculate_constant_node_outputs(
                    node,
                    &shapes,
                    &inputs,
                    &output_shapes,
                    &initializers,
                    opset_version,
                )
                .await
                .map_err(ShapeInferenceError::ConstantFoldingError)?
                {
                    for (output_index, output_name) in node.output.iter().enumerate().rev() {
                        let output_tensor = constant_output.remove(output_index);

                        let output_shape = &output_shapes[output_index];
                        let mut initializer: TensorProto = TensorProto::from(
                            output_tensor,
                            output_shape.dims.iter().map(|x| *x as i64).collect(),
                        );
                        initializer.set_name(output_name.clone());
                        initializer.set_dims(output_shape.dims.iter().map(|x| *x as i64).collect());
                        initializers.insert(output_name.clone(), Cow::Owned(initializer));

                        assert_eq!(
                            &shapes[output_name], output_shape,
                            "output shape should be the same after folding"
                        );
                        folded_node_indexes.push(node_index);

                        log::info!(
                            "folded output '{output_name}' (#{output_index}) of node {} shape={output_shape}",
                            node.get_name(),
                        );
                    }
                } else {
                    foldable_nodes.push(node.get_name().to_string());
                }
            }
        }
    }

    folded_node_indexes.sort();
    for index in folded_node_indexes.iter().rev() {
        graph.node.remove(*index);
    }

    let new_initializers: Vec<TensorProto> = initializers
        .into_iter()
        .flat_map(|(_, x)| match x {
            Cow::Owned(z) => Some(z),
            Cow::Borrowed(_) => None,
        })
        .collect();

    for new_initializer in new_initializers {
        graph.initializer.push(new_initializer);
    }

    if !foldable_nodes.is_empty() {
        log::info!(
            "The following nodes can likely be folded, but currently aren't due to missing support: {}",
            foldable_nodes.join(", ")
        );
    }

    Ok(())
}

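/// Infers the output shapes of a single node, given the shapes of its inputs and the graph's static initializers.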
pub(crate) fn infer_output_shapes(
    node: &NodeProto,
    input_shapes: &[&Shape],
    initializers: &HashMap<String, Cow<TensorProto>>,
) -> Result<Vec<Shape>, ShapeInferenceError> {
    match (
        node.get_op_type(),
        input_shapes.len(),
        node.get_output().len(),
    ) {
        ("Clip", 1..=3, 1)
        | (
            "Identity" | "Sqrt" | "Relu" | "LeakyRelu" | "Abs" | "Acos" | "Acosh" | "Asin" | "Sin"
            | "Asinh" | "Atan" | "Atanh" | "Cos" | "Cosh" | "Elu" | "Erf" | "Exp" | "Log" | "Neg"
            | "Ceil" | "Floor" | "Reciprocal" | "Celu" | "Sign",
            1,
            1,
        ) => Ok(vec![input_shapes[0].clone()]),

492 ("Cast", 1, 1) => {
493 let to_value: i64 = node
494 .get_attribute_value("to", None)
495 .map_err(ShapeInferenceError::MissingAttribute)?;
496 let to_data_type = ScalarType::from_i32(to_value as i32).map_err(|_| {
497 ShapeInferenceError::InvalidNode(
498 node.get_name().to_string(),
499 format!(
500 "invalid value for to attribute ({}) for Cast operator",
501 to_value
502 ),
503 )
504 })?;
505
506 let mut output_shape = input_shapes[0].clone();
507 output_shape.data_type = to_data_type;
508
509 Ok(vec![output_shape])
510 }
511
512 ("Flatten", 1, 1) => {
513 let axis: usize = {
514 let a = node.get_attribute_value("axis", Some(1)).unwrap();
515 if a < 0 {
516 (a + input_shapes[0].rank() as i64) as usize
517 } else {
518 a as usize
519 }
520 };
521 if axis > input_shapes[0].rank() {
522 return Err(ShapeInferenceError::InvalidNode(
523 node.get_name().to_string(),
524 format!("Flatten axis attribute ({axis}) should be less than or equal to rank of input ({})",input_shapes[0].rank()),
525 ));
526 }
527 let input_dims = &input_shapes[0].dims;
528 let outer_dim = if axis == 0 {
529 1
530 } else {
531 input_dims[0..=(axis - 1)].iter().product::<u64>() as i64
532 };
533 let inner_dim = input_dims[axis..].iter().product::<u64>() as i64;
534
535 let new_dims = vec![outer_dim, inner_dim];
536 Ok(vec![Shape::from(input_shapes[0].data_type, &new_dims)])
537 }
538
539 ("GlobalAveragePool", 1, 1) => {
540 let mut output_shape = input_shapes[0].clone();
541 if output_shape.rank() < 2 {
542 return Err(ShapeInferenceError::InvalidNode(
543 node.get_name().to_string(),
544 format!("invalid input rank for GlobalAveragePool: {output_shape}",),
545 ));
546 }
547 for a in 2..output_shape.dims.len() {
548 output_shape.dims[a] = 1;
549 }
550 Ok(vec![output_shape])
551 }
552
553 ("Gather", 2, 1) => {
554 let r = input_shapes[0].rank() as i64;
556 if r < 1 {
557 return Err(ShapeInferenceError::InvalidNode(
558 node.get_name().to_string(),
559 "data tensor must have rank 1 or greater".to_string(),
560 ));
561 }
562 let q = input_shapes[1].rank() as i64;
563 let mut axis = node
564 .get_attribute_value("axis", Some(0))
565 .map_err(ShapeInferenceError::MissingAttribute)?;
566 if axis >= r || axis < -r {
567 return Err(ShapeInferenceError::InvalidNode(
568 node.get_name().to_string(),
569 "axis must be less than data tensor rank".to_string(),
570 ));
571 }
572
573 if axis < 0 {
574 axis += r;
575 }
576 let out_rank = q + r - 1;
577 Ok(vec![Shape::from(
578 input_shapes[0].data_type,
579 (0..out_rank)
580 .map(|idx| {
581 if idx < axis {
582 input_shapes[0].dim(idx as usize) as i64
583 } else if idx >= axis && idx < (axis + q) {
584 input_shapes[1].dim((idx - axis) as usize) as i64
585 } else {
586 input_shapes[0].dim((idx - q + 1) as usize) as i64
587 }
588 })
589 .collect::<Vec<i64>>()
590 .as_ref(),
591 )])
592 }
593
594 ("Shape", 1, 1) => {
595 let rank = input_shapes[0].rank() as i64;
596 let mut start: i64 = node.get_attribute_value("start", Some(0)).unwrap();
597 let mut end: i64 = node.get_attribute_value("end", Some(rank)).unwrap();
598 if start < 0 {
599 start += rank;
600 }
601 if end < 0 {
602 end += rank;
603 }
604
605 Ok(vec![Shape::from(
606 ScalarType::I64,
607 &[rank.clamp(start, end)],
608 )])
609 }
610
611 ("Size", 1, 1) => Ok(vec![Shape::from(ScalarType::I64, &[1])]),
612
613 ("Slice", num_inputs @ 3..=5, 1) => {
614 let data_shape = input_shapes[0];
615
616 let mut starts: Vec<i64> =
619 static_initializer_value_i64(initializers, &node.get_input()[1])?
620 .iter()
621 .enumerate()
622 .map(|(idx, s)| {
623 if *s < 0 {
624 *s + data_shape.dim(idx) as i64
625 } else {
626 *s
627 }
628 })
629 .collect();
630 if starts.is_empty() {
631 log::warn!(
632 "starts not set for Slice, generating it... name={}",
633 node.get_input()[1]
634 );
635 starts = (0..data_shape.rank()).map(|_| 1).collect();
636 }
637 let mut ends: Vec<i64> =
638 static_initializer_value_i64(initializers, &node.get_input()[2])?
639 .iter()
640 .enumerate()
641 .map(|(idx, s)| {
642 if *s < 0 {
643 *s + data_shape.dim(idx) as i64
644 } else {
645 *s
646 }
647 })
648 .collect();
649 if ends.is_empty() {
650 log::warn!("ends not set for Slice, generating it...");
651 ends = data_shape.dims.iter().map(|x| *x as i64).collect();
652 }
653
654 let axes: Vec<i64> = if num_inputs > 3 {
656 let x: Vec<i64> =
657 static_initializer_value_i64(initializers, &node.get_input()[3])?.into();
658 if x.is_empty() {
659 (0..(data_shape.rank() as i64)).collect()
660 } else {
661 x
662 }
663 } else {
664 (0..(data_shape.rank() as i64)).collect()
665 };
666
667 let steps: Vec<i64> = if num_inputs > 4 {
669 static_initializer_value_i64(initializers, &node.get_input()[4])?.into()
670 } else {
671 log::debug!(
672 "steps not set for slice, generating it (data_shape rank={})",
673 data_shape.rank()
674 );
675 (0..(data_shape.rank() as i64)).map(|_| 1).collect()
676 };
677
678 if axes.len() != steps.len() {
679 return Err(ShapeInferenceError::InvalidNode(node.get_name().to_string(), format!("length of axes attribute ({}) must be equal to length of steps attribute ({})", axes.len(), steps.len())));
680 }
681
682 let axes: Vec<i64> = axes
684 .into_iter()
685 .map(|x| {
686 if x < 0 {
687 x + data_shape.rank() as i64
688 } else {
689 x
690 }
691 })
692 .collect();
693
694 let mut output_shape: Vec<i64> =
695 input_shapes[0].dims.iter().map(|x| *x as i64).collect();
696
697 for (axis_index, axis) in axes.iter().enumerate() {
699 let mut start = starts[axis_index];
700 let mut end = ends[axis_index];
701 let mut step = steps[axis_index];
702 process_slice_inputs(
703 data_shape.dim(*axis as usize) as i64,
704 &mut start,
705 &mut end,
706 &mut step,
707 )?;
708 let temp = div_ceil(end - start, step).max(0);
709 output_shape[*axis as usize] = temp;
710 }
711
712 Ok(vec![Shape::from(data_shape.data_type, &output_shape)])
713 }
714
        (
            "ReduceMean" | "ReduceSum" | "ReduceMin" | "ReduceMax" | "ReduceSumSquare"
            | "ReduceLogSumExp" | "ReduceLogSum" | "ReduceL2" | "ReduceL1" | "ReduceProd",
            1,
            1,
        ) => {
            let noop_with_empty_axes = node
                .get_attribute_value("noop_with_empty_axes", Some(0))
                .map_err(ShapeInferenceError::MissingAttribute)?;

            let input_shape = input_shapes[0];
            let input_ndim = input_shape.rank();
            let all_axes: Vec<i64> = if noop_with_empty_axes == 0 {
                (0..(input_shape.dims.len() as i64)).collect()
            } else {
                vec![]
            };
            let axes: Vec<i64> = node
                .get_attribute_value("axes", Some(all_axes))
                .map_err(ShapeInferenceError::MissingAttribute)?
                .into_iter()
                .map(|idx| {
                    if idx < 0 {
                        (input_ndim as i64) + idx
                    } else {
                        idx
                    }
                })
                .collect();
            let keep_dims = node
                .get_attribute_value("keepdims", Some(1))
                .map_err(ShapeInferenceError::MissingAttribute)?;

            Ok(vec![Shape::from(
                input_shape.data_type,
                (0..input_ndim as i64)
                    .flat_map(|i| {
                        if !axes.contains(&i) {
                            vec![input_shape.dim(i as usize) as i64]
                        } else if keep_dims == 1 {
                            vec![1]
                        } else {
                            vec![]
                        }
                    })
                    .collect::<Vec<_>>()
                    .as_ref(),
            )])
        }

        ("Sub" | "Pow" | "Add" | "Div" | "Mul" | "Mod", 2, 1) => {
            if let Some(output_shape) =
                Shape::multi_broadcast(&[input_shapes[0].clone(), input_shapes[1].clone()])
            {
                Ok(vec![output_shape])
            } else {
                Err(ShapeInferenceError::InvalidNode(
                    node.get_name().to_string(),
                    format!(
                        "two inputs (left {} shape: {}, right {} shape: {}) must be broadcastable",
                        node.get_input()[0],
                        input_shapes[0],
                        node.get_input()[1],
                        input_shapes[1]
                    ),
                ))
            }
        }

786 ("Conv", 2, num_outputs @ 1)
787 | ("Conv", 3, num_outputs @ 1)
788 | ("MaxPool", 1, num_outputs @ 1)
789 | ("MaxPool", 1, num_outputs @ 2)
790 | ("AveragePool", 1, num_outputs @ 1)
791 | ("AveragePool", 1, num_outputs @ 2) => {
792 let use_dilation = true;
794 let require_kernel_shape = matches!(node.get_op_type(), "MaxPool" | "AveragePool");
795 let input_shape = input_shapes[0];
796 if input_shape.rank() < 2 {
797 return Err(ShapeInferenceError::InvalidNode(
798 node.get_name().to_string(),
799 "input shape must have at least two dimensions".to_string(),
800 ));
801 }
802
803 let num_input_dims = input_shape.rank() - 2;
804
805 let dilations: Vec<i64> = if use_dilation && node.has_attribute("dilations") {
807 let dilations_attr: Vec<i64> = node
808 .get_attribute_value("dilations", None)
809 .map_err(ShapeInferenceError::MissingAttribute)?;
810 if dilations_attr.len() != num_input_dims {
811 return Err(ShapeInferenceError::InvalidNode(
812 node.get_name().to_string(),
813 "attribute dilations has incorrect size".to_string(),
814 ));
815 }
816 dilations_attr
817 } else {
818 (0..num_input_dims).map(|_| 1).collect()
819 };
820
821 let strides: Vec<i64> = if use_dilation && node.has_attribute("strides") {
823 let strides_attr: Vec<i64> = node
824 .get_attribute_value("strides", None)
825 .map_err(ShapeInferenceError::MissingAttribute)?;
826 if strides_attr.len() != num_input_dims {
827 return Err(ShapeInferenceError::InvalidNode(
828 node.get_name().to_string(),
829 "attribute strides has incorrect size".to_string(),
830 ));
831 }
832 strides_attr
833 } else {
834 (0..num_input_dims).map(|_| 1).collect()
835 };
836
837 let kernel_shape = if node.has_attribute("kernel_shape") {
839 node.get_attribute_value::<Vec<i64>>("kernel_shape", None)
840 .map_err(ShapeInferenceError::MissingAttribute)?
841 } else if require_kernel_shape {
842 return Err(ShapeInferenceError::InvalidNode(
843 node.get_name().to_string(),
844 "node requires kernel_shape to be set".to_string(),
845 ));
846 } else {
847 input_shapes[1].dims[2..]
849 .iter()
850 .map(|x| *x as i64)
851 .collect()
852 };
853
854 if kernel_shape.len() != num_input_dims {
855 return Err(ShapeInferenceError::InvalidNode(
856 node.get_name().to_string(),
857 "kernel shape rank must be equal to input rank".to_string(),
858 ));
859 }
860
861 let effective_kernel_shape: Vec<i64> = kernel_shape
863 .iter()
864 .enumerate()
865 .map(|(idx, dim)| (*dim - 1) * dilations[idx] + 1)
866 .collect();
867
868 let pads = if node.has_attribute("pads") {
870 let p = node
871 .get_attribute_value::<Vec<i64>>("pads", None)
872 .map_err(ShapeInferenceError::MissingAttribute)?;
873 if p.len() != num_input_dims * 2 {
874 return Err(ShapeInferenceError::InvalidNode(
875 node.get_name().to_string(),
876 "pads attribute has incorrect size".to_string(),
877 ));
878 }
879 p
880 } else {
881 let mut pads: Vec<i64> = (0..num_input_dims * 2).map(|_| 0).collect();
882 let auto_pad = node
883 .get_attribute_value("auto_pad", Some(String::from("VALID")))
884 .unwrap();
885
886 if auto_pad != "VALID" {
887 for i in 0..num_input_dims {
888 let mut residual: i64 = 0;
889 let stride = strides[i];
890
891 if stride > 1 {
892 residual = input_shape.dim(2 + i) as i64;
893 while residual >= stride {
894 residual -= stride;
895 }
896 }
897
898 let mut total_pad = if residual == 0 {
899 effective_kernel_shape[i] - stride
900 } else {
901 effective_kernel_shape[i] - residual
902 };
903 if total_pad < 0 {
904 total_pad = 0;
905 }
906
907 let half_pad_small = total_pad >> 1;
908 let half_pad_big = total_pad - half_pad_small;
909 if auto_pad == "SAME_UPPER" {
910 pads[i] = half_pad_small;
911 pads[i + num_input_dims] = half_pad_big;
912 } else if auto_pad == "SAME_LOWER" {
913 pads[i] = half_pad_big;
914 pads[i + num_input_dims] = half_pad_small;
915 }
916 }
917 }
918 pads
919 };
920
921 let mut output_shape: Vec<i64> = vec![];
923 output_shape.push(input_shape.dim(0) as i64);
924 if require_kernel_shape {
925 output_shape.push(input_shape.dim(1) as i64);
926 } else {
927 if input_shapes[1].rank() < 1 {
928 return Err(ShapeInferenceError::InvalidNode(
929 node.get_name().to_string(),
930 "second input has incorrect rank".to_string(),
931 ));
932 }
933 output_shape.push(input_shapes[1].dim(0) as i64);
934 }
935
936 let kernel_shape_size = kernel_shape.len();
937 for i in 0..kernel_shape_size {
938 let mut effective_input_size: i64 = input_shape.dim(2 + i) as i64;
940 effective_input_size += pads[i];
941 effective_input_size += pads[i + kernel_shape_size];
942
943 let ceil_mode = node.get_attribute_value("ceil_mode", Some(0)).unwrap();
945
946 let strided_kernel_positions = if ceil_mode == 1 {
949 div_ceil(effective_input_size - effective_kernel_shape[i], strides[i])
950 } else {
951 (effective_input_size - effective_kernel_shape[i]) / strides[i]
952 };
953
954 output_shape.push(1 + strided_kernel_positions);
955 }
956
957 let final_output_shape = Shape::from(input_shape.data_type, &output_shape);
959 Ok((0..num_outputs)
960 .map(|_| final_output_shape.clone())
961 .collect())
962 }
963
964 ("ConstantOfShape", 1, 1) => {
965 let shape = static_initializer_value_i64(initializers, &node.get_input()[0])?;
966
967 let value = node
968 .get_attribute_value::<TensorProto>("value", None)
969 .map_err(ShapeInferenceError::MissingAttribute)?;
970
971 let data_type = ScalarType::from_i32(value.get_data_type())
972 .map_err(ShapeInferenceError::UnsupportedDataType)?;
973
974 Ok(vec![Shape::from(data_type, shape)])
975 }
976
977 ("Constant", 0, 1) => {
978 if let Ok(values) = node.get_attribute_value::<Vec<f32>>("value_floats", None) {
979 Ok(vec![Shape::from(ScalarType::F32, &[values.len() as i64])])
980 } else if let Ok(values) = node.get_attribute_value::<Vec<i64>>("value_ints", None) {
981 Ok(vec![Shape::from(ScalarType::I64, &[values.len() as i64])])
982 } else if node.get_attribute_value::<f32>("value_float", None).is_ok() {
983 Ok(vec![Shape::from(ScalarType::F32, &[1])])
984 } else if node.get_attribute_value::<i64>("value_int", None).is_ok() {
985 Ok(vec![Shape::from(ScalarType::I64, &[1])])
986 } else if let Ok(tp) = node.get_attribute_value::<TensorProto>("value", None) {
987 Ok(vec![Shape::from(
988 ScalarType::from_i32(tp.get_data_type()).map_err(|_| {
989 ShapeInferenceError::InvalidNode(
990 node.get_name().to_string(),
991 "invalid tensor data type".to_string(),
992 )
993 })?,
994 tp.get_dims(),
995 )])
996 } else {
997 log::debug!("{:#?}", node);
998 Err(ShapeInferenceError::Unsupported("Constant".to_string()))
999 }
1000 }
1001
1002 ("Reshape", 2, 1) => {
1003 let shape_tensor_name = &node.get_input()[1];
1004
1005 if let Some(shape_tensor) = initializers.get(shape_tensor_name) {
1006 let allow_zero = node.get_attribute_value("allowzero", Some(0)).unwrap() == 1;
1007
1008 let shape_tensor_contents = shape_tensor.get_int64_data();
1010
1011 for dim in shape_tensor_contents {
1013 match *dim {
1014 -1 => return Err(ShapeInferenceError::Unsupported(
1015 "Reshape with shape containing a -1 element".to_string(),
1016 )),
1017 i64::MIN..=-1 => return Err(ShapeInferenceError::InvalidNode(
1018 node.get_name().to_string(),
1019 format!("Reshape shape tensor cannot contain negative values except for -1 (contains {})", dim))),
1020 0..=i64::MAX => ()
1021 }
1022 }
1023
1024 let output_shape: Vec<i64> = shape_tensor_contents
1025 .iter()
1026 .enumerate()
1027 .map(|(idx, dim)| {
1028 if *dim == 0 && !allow_zero {
1029 input_shapes[0].dim(idx) as i64
1030 } else {
1031 *dim
1032 }
1033 })
1034 .collect();
1035
1036 if output_shape.iter().product::<i64>() != input_shapes[0].element_count() as i64 {
1037 return Err(ShapeInferenceError::InvalidNode(
1038 node.get_name().to_string(),
1039 format!("Reshape input tensor (element count={}) must have the same number of elements as specified by the new shape ({})", input_shapes[0].element_count(), output_shape.iter().product::<i64>())));
1040 }
1041
1042 Ok(vec![Shape::from(input_shapes[0].data_type, &output_shape)])
1043 } else {
1044 Err(ShapeInferenceError::Unsupported(format!(
1045 "Reshape with dynamic shape tensor (input name is {shape_tensor_name})"
1046 )))
1047 }
1048 }
1049
1050 ("Concat", 1.., 1) => {
1051 let axis = node
1052 .get_attribute_value::<i64>("axis", None)
1053 .map_err(ShapeInferenceError::MissingAttribute)?;
1054
1055 let mut shape: Vec<i64> = input_shapes[0].dims.iter().map(|x| *x as i64).collect();
1057 if axis < -(shape.len() as i64) || axis > (shape.len() - 1) as i64 {
1058 return Err(ShapeInferenceError::InvalidNode(
1059 node.get_name().to_string(),
1060 "axis attribute needs to be smaller than input tensor rank".to_string(),
1061 ));
1062 }
1063
1064 let axis_index = if axis < 0 {
1065 ((shape.len() as i64) + axis) as usize
1066 } else {
1067 axis as usize
1068 };
1069 shape[axis_index] = input_shapes.iter().map(|s| s.dim(axis_index) as i64).sum();
1070 Ok(vec![Shape::from(input_shapes[0].data_type, &shape)])
1071 }
1072
1073 ("Dropout", 1..=3, num_outputs @ 1..=2) => {
1074 let shape = input_shapes[0];
1075 Ok((0..num_outputs).map(|_| shape.clone()).collect())
1076 }
1077
1078 ("Unsqueeze", num_inputs @ 1..=2, 1) => {
1079 let axes: Vec<i64> = if num_inputs == 2 {
1080 let shape_tensor_name = &node.get_input()[1];
1081 if let Some(shape_tensor) = initializers.get(shape_tensor_name) {
1082 shape_tensor.get_int64_data().to_vec()
1084 } else {
1085 return Err(ShapeInferenceError::Unsupported(
1086 "Unsqueeze with dynamic axis inputs".to_string(),
1087 ));
1088 }
1089 } else {
1090 node.get_attribute_value("axes", None)
1091 .map_err(ShapeInferenceError::MissingAttribute)?
1092 };
1093
1094 let output_rank = input_shapes[0].rank() + axes.len();
1095 let mut input_shape: Vec<i64> =
1096 input_shapes[0].dims.iter().map(|x| *x as i64).collect();
1097 for i in axes {
1098 let index = if i < 0 {
1099 ((output_rank as i64) + i) as usize
1100 } else {
1101 i as usize
1102 };
1103 input_shape.insert(index, 1);
1104 }
1105
1106 Ok(vec![Shape::from(input_shapes[0].data_type, &input_shape)])
1107 }
1108
1109 ("Range", 3, 1) => {
1110 let start = static_initializer_value_i64(initializers, &node.input[0])?;
1112 let end = static_initializer_value_i64(initializers, &node.input[1])?;
1113 let step = static_initializer_value_i64(initializers, &node.input[2])?;
1114
1115 if start.len() != 1 {
1116 return Err(ShapeInferenceError::InvalidNode(
1117 node.get_name().to_string(),
1118 format!(
1119 "the start input needs to be a scalar, has {} elements",
1120 start.len()
1121 ),
1122 ));
1123 }
1124
1125 if end.len() != 1 {
1126 return Err(ShapeInferenceError::InvalidNode(
1127 node.get_name().to_string(),
1128 format!(
1129 "the end input needs to be a scalar, has {} elements",
1130 end.len()
1131 ),
1132 ));
1133 }
1134
1135 if step.len() != 1 {
1136 return Err(ShapeInferenceError::InvalidNode(
1137 node.get_name().to_string(),
1138 format!(
1139 "the step input needs to be a scalar, has {} elements",
1140 step.len()
1141 ),
1142 ));
1143 }
1144
1145 let element_count = (end[0] - start[0]) / step[0];
1146 Ok(vec![Shape::from(ScalarType::I64, &[element_count])])
1147 }
1148
1149 ("Squeeze", num_inputs @ 1..=2, 1) => {
1150 let has_axes = num_inputs == 2;
1151 let axes: Vec<i64> = if has_axes {
1152 let shape_tensor_name = &node.get_input()[1];
1153 if let Some(shape_tensor) = initializers.get(shape_tensor_name) {
1154 shape_tensor.get_int64_data().to_vec()
1156 } else {
1157 return Err(ShapeInferenceError::Unsupported(
1158 "Unsqueeze with dynamic axis inputs".to_string(),
1159 ));
1160 }
1161 } else {
1162 vec![]
1163 };
1164
1165 let output_shape: Vec<i64> = input_shapes[0]
1166 .dims
1167 .iter()
1168 .enumerate()
1169 .flat_map(|(idx, dim)| {
1170 if (has_axes && axes.contains(&(idx as i64))) || (!has_axes && *dim == 1) {
1171 vec![]
1172 } else {
1173 vec![*dim as i64]
1174 }
1175 })
1176 .collect();
1177
1178 Ok(vec![Shape::from(input_shapes[0].data_type, &output_shape)])
1179 }
1180
1181 ("Transpose", 1, 1) => {
1182 let input_dims: Vec<i64> = input_shapes[0].dims.iter().map(|x| *x as i64).collect();
1183 let output_dims: Vec<i64> = match node.get_attribute_value::<Vec<i64>>("perm", None) {
1184 Ok(perm) => perm.iter().map(|idx| input_dims[*idx as usize]).collect(),
1185 Err(_) => input_dims.iter().rev().cloned().collect(),
1186 };
1187 Ok(vec![Shape::from(input_shapes[0].data_type, &output_dims)])
1188 }
1189
1190 ("BatchNormalization", 1.., 1) => {
1191 Ok(vec![input_shapes[0].clone()])
1193 }
1194
        (
            "ReduceMean" | "ReduceSum" | "ReduceMin" | "ReduceMax" | "ReduceSumSquare"
            | "ReduceLogSumExp" | "ReduceLogSum" | "ReduceL2" | "ReduceL1" | "ReduceProd",
            2,
            1,
        ) => Err(ShapeInferenceError::Unsupported(format!(
            "{} with two inputs (axes input not supported)",
            node.get_op_type()
        ))),

        (
            "Sub" | "Pow" | "Add" | "Div" | "Mul" | "Identity" | "Sqrt" | "ReduceMean" | "Gather"
            | "Constant" | "Relu" | "LeakyRelu" | "MaxPool" | "Conv" | "AveragePool" | "Reshape"
            | "Concat" | "Unsqueeze" | "Cast" | "Squeeze" | "Shape" | "Slice" | "Range"
            | "ConstantOfShape" | "Transpose" | "Abs" | "Acos" | "Acosh" | "Asin" | "Sin" | "Asinh"
            | "Atan" | "Atanh" | "Cos" | "Cosh" | "Elu" | "Erf" | "Exp" | "Log" | "Neg" | "Ceil"
            | "Reciprocal" | "Floor" | "Mod" | "Celu" | "ReduceSum" | "ReduceMin" | "ReduceMax"
            | "ReduceSumSquare" | "ReduceLogSumExp" | "ReduceLogSum" | "ReduceL2" | "ReduceL1"
            | "ReduceProd" | "Size" | "Sign",
            _,
            _,
        ) => Err(ShapeInferenceError::InvalidNode(
            node.get_name().to_string(),
            format!(
                "invalid number of inputs ({}) or outputs ({})",
                node.get_input().len(),
                node.get_output().len()
            ),
        )),

        (op_type, _inputs, _outputs) => {
            log::debug!("Shape inference unimplemented for op {op_type} with input shapes {input_shapes:#?}");
            Err(ShapeInferenceError::Unsupported(op_type.to_string()))
        }
    }
}

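/// Normalizes the start/end/step values of a Slice operation for one sliced axis (whose size is passed as
/// `input_rank`), resolving negative indices and clamping to valid bounds per the ONNX Slice semantics.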
fn process_slice_inputs(
    input_rank: i64,
    start: &mut i64,
    end: &mut i64,
    step: &mut i64,
) -> Result<(), ShapeInferenceError> {
    if *step == 0 {
        return Err(ShapeInferenceError::InvalidNode(
            "".to_string(),
            "step value must not be zero for slice".to_string(),
        ));
    }
    if *start < 0 {
        *start += input_rank;
    }
    if *step < 0 {
        *start = (*start).clamp(0, input_rank - 1);
    } else {
        *start = (*start).clamp(0, input_rank);
    }

    if *end < 0 {
        *end += input_rank;
    }
    if *step < 0 {
        *end = (*end).clamp(-1, input_rank - 1);
    } else {
        *end = (*end).clamp(0, input_rank);
    }
    Ok(())
}

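/// Moves raw tensor data into the typed data fields of a `TensorProto` (float_data, int64_data, etc.) so
/// subsequent code can read the values without going through the raw bytes.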
fn fix_raw_tensor(tensor: &mut TensorProto) -> Result<(), ShapeInferenceError> {
    if tensor.has_raw_data() {
        let raw_data = tensor.take_raw_data();
        match ScalarType::from_i32(tensor.get_data_type())
            .map_err(ShapeInferenceError::UnsupportedDataType)?
        {
            ScalarType::F32 => tensor.set_float_data(bytemuck::cast_slice(&raw_data[..]).to_vec()),
            ScalarType::I64 => tensor.set_int64_data(bytemuck::cast_slice(&raw_data[..]).to_vec()),
            ScalarType::I32 => tensor.set_int32_data(bytemuck::cast_slice(&raw_data[..]).to_vec()),
            ScalarType::U8 => tensor.set_raw_data(bytemuck::cast_slice(&raw_data[..]).to_vec()),
        }
    }
    Ok(())
}

#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use protobuf::Message;
    use wonnx::onnx::ModelProto;

    use crate::shape_inference::infer_shapes;

    use super::dimensions_infos;

    async fn test_shape_inference_for_model(path: &str, should_fold_constants: bool) {
        let mut model =
            ModelProto::parse_from_bytes(&std::fs::read(path).expect("ONNX Model path not found."))
                .unwrap();

        let graph = model.mut_graph();
        let infos = dimensions_infos(graph).unwrap();
        graph.value_info.clear();
        infer_shapes(graph, should_fold_constants, 13)
            .await
            .unwrap();
        let new_infos = dimensions_infos(graph).unwrap();

        let keys_in_old: HashSet<String> = infos.keys().cloned().collect();
        let keys_in_new: HashSet<String> = new_infos.keys().cloned().collect();
        let all_keys: HashSet<String> = keys_in_old.union(&keys_in_new).cloned().collect();

        for key in all_keys {
            if !keys_in_old.contains(&key) || !keys_in_new.contains(&key) || infos[&key].is_empty()
            {
                // Shape was only known before or only after inference (or was empty); nothing to compare.
            } else {
                assert_eq!(
                    infos[&key], new_infos[&key],
                    "different shape inferred for {key}"
                )
            }
        }
    }

    #[test]
    fn test_shape_inference() {
        let _ = env_logger::builder().is_test(true).try_init();

        pollster::block_on(async {
            test_shape_inference_for_model("../data/models/opt-mnist.onnx", false).await;
            test_shape_inference_for_model("../data/models/opt-squeeze.onnx", false).await;
            test_shape_inference_for_model("../data/models/single_relu.onnx", false).await;
            test_shape_inference_for_model("../data/models/single_relu.onnx", false).await;
            test_shape_inference_for_model("../data/models/mobilenetv2-7.onnx", true).await;
        });
    }
}