wonnx_preprocessing/shape_inference.rs

use std::{borrow::Cow, collections::HashMap};

use protobuf::ProtobufEnum;
use thiserror::Error;
use wonnx::{
    onnx::{
        GraphProto, NodeProto, TensorProto, TensorShapeProto, TensorShapeProto_Dimension,
        TypeProto, TypeProto_Tensor, TypeProto_oneof_value, ValueInfoProto,
    },
    utils::{
        AttributeNotFoundError, DataTypeError, InputTensor, NodeAttributes, ScalarType, Shape,
    },
};

use crate::constant_folding::{calculate_constant_node_outputs, ConstantFoldingError};

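/// Replace named dynamic dimension parameters (e.g. "batch_size") in the graph's inputs, outputs
/// and intermediate value infos with the concrete values supplied in `dynamic_dims`.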
pub fn apply_dynamic_dimensions(graph: &mut GraphProto, dynamic_dims: &HashMap<String, i64>) {
    // Apply to values
    for value_info in graph.mut_value_info() {
        apply_dynamic_dimensions_value(value_info, dynamic_dims);
    }

    for value_info in graph.mut_input() {
        apply_dynamic_dimensions_value(value_info, dynamic_dims);
    }

    for value_info in graph.mut_output() {
        apply_dynamic_dimensions_value(value_info, dynamic_dims);
    }
}

/// Divide `num` by `div`, rounding up to the next whole number when there is a remainder (ceiling division).
fn div_ceil(num: i64, div: i64) -> i64 {
    num / div + (num % div != 0) as i64
}

/// Retrieve the value of the initializer with the given name as a slice of i64 values.
fn static_initializer_value_i64<'a>(
    initializers: &'a HashMap<String, Cow<'a, TensorProto>>,
    name: &str,
) -> Result<&'a [i64], ShapeInferenceError> {
    if let Some(shape_tensor) = initializers.get(name) {
        if shape_tensor.get_data_type() != ScalarType::I64.to_datatype().value() {
            return Err(ShapeInferenceError::Unsupported(format!(
                "initializer {} has data type {} and not int64, which is currently not supported",
                name,
                shape_tensor.get_data_type()
            )));
        }

        let expected_value_count: i64 = shape_tensor.get_dims().iter().product();

        // Read data from the int64_data field, except when its length doesn't match the expected
        // element count; in that case, fall back to the raw_data field.
        if shape_tensor.get_int64_data().len() != expected_value_count as usize {
            let raw_data = shape_tensor.get_raw_data();
            if raw_data.len() / 8 == expected_value_count as usize {
                // Raw data has the required size, use that (raw data should be little-endian, see https://github.com/onnx/onnx/issues/2825)
                log::warn!(
                    "int64 data for initializer {name} contains {} values, expected {expected_value_count}. Raw data length ({}) matches however, using that. dims={:?}",
                    shape_tensor.get_int64_data().len(),
                    shape_tensor.get_raw_data().len(),
                    shape_tensor.get_dims()
                );
                return Ok(bytemuck::cast_slice(raw_data));
            } else {
                log::warn!(
                    "int64 data for initializer {name} contains {} values, expected {expected_value_count}. Raw data length ({}) doesn't match either! dims={:?}",
                    shape_tensor.get_int64_data().len(),
                    shape_tensor.get_raw_data().len(),
                    shape_tensor.get_dims()
                );
            }
        }

        // Get the tensor's contents
        Ok(shape_tensor.get_int64_data())
    } else {
        Err(ShapeInferenceError::Unsupported(format!(
            "input {} is dynamic (only static initializers are supported)",
            name
        )))
    }
}

/// Replaces dimension params with provided values
fn apply_dynamic_dimensions_value(
    value_info: &mut ValueInfoProto,
    dynamic_dims: &HashMap<String, i64>,
) {
    let name = value_info.get_name().to_string();
    let field_type = value_info.mut_field_type();

    if let Some(TypeProto_oneof_value::tensor_type(field_type_value)) = &mut field_type.value {
        let dims = field_type_value.mut_shape().mut_dim();

        for (idx, dim) in dims.iter_mut().enumerate() {
            if let Some(new_dim_value) = dynamic_dims.get(dim.get_dim_param()) {
                println!(
                    "Setting dimension param {idx} ({}) to value {new_dim_value} for {name}",
                    dim.get_dim_param()
                );
                dim.clear_dim_param();
                dim.set_dim_value(*new_dim_value);
            }
        }
    }
}

/// Retrieve all fully known value shapes from a graph
pub(crate) fn dimensions_infos(
    graph_proto: &GraphProto,
) -> Result<HashMap<String, Shape>, DataTypeError> {
    let mut shapes_info = HashMap::new();

    for info in graph_proto.get_input() {
        if let Ok(shape) = info.get_shape() {
            shapes_info.insert(info.get_name().to_string(), shape);
        }
    }

    for info in graph_proto.get_output() {
        if let Ok(shape) = info.get_shape() {
            if shapes_info
                .insert(info.get_name().to_string(), shape)
                .is_some()
            {
                log::warn!(
                    "already had shape information for '{}', replacing from outputs",
                    info.get_name()
                );
            }
        }
    }

    for info in graph_proto.get_value_info() {
        if let Ok(shape) = info.get_shape() {
            if shapes_info
                .insert(info.get_name().to_string(), shape)
                .is_some()
            {
                log::warn!(
                    "already had shape information for '{}', replacing from value_info",
                    info.get_name()
                );
            }
        }
    }

    for info in graph_proto.get_initializer() {
        if let Ok(data_type) = ScalarType::from_i32(info.get_data_type()) {
            let shape = Shape::from(data_type, info.get_dims());
            if shapes_info
                .insert(info.get_name().to_string(), shape)
                .is_some()
            {
                log::warn!(
                    "already had shape information for '{}', replacing from initializer",
                    info.get_name()
                );
            }
        }
    }

    Ok(shapes_info)
}

#[derive(Error, Debug)]
pub enum ShapeInferenceError {
    #[error("missing shape for input {0}")]
    MissingInputShape(String),

    #[error("incomplete or missing shape for input {0} - be sure to specify all dynamic dimension parameters")]
    IncompleteInputShape(String),

    #[error("unsupported: {0}")]
    Unsupported(String),

    #[error("node {0} is invalid: {1}")]
    InvalidNode(String, String),

    #[error("attribute {0} required for shape inference is missing")]
    #[from(AttributeNotFoundError)]
    MissingAttribute(AttributeNotFoundError),

    #[error("unsupported data type encountered: {0}")]
    #[from(DataTypeError)]
    UnsupportedDataType(DataTypeError),

    #[error("constant folding failed: {0}")]
    #[from(ConstantFoldingError)]
    ConstantFoldingError(ConstantFoldingError),
}

/// Replaces nodes of op type Constant with an initializer
fn replace_constant_ops_with_initializers(
    graph: &mut GraphProto,
) -> Result<(), ShapeInferenceError> {
    for node_index in (0..graph.node.len()).rev() {
        let is_constant = graph.node[node_index].get_op_type() == "Constant";

        if is_constant {
            {
                let node = &graph.node[node_index];
                if node.get_output().len() != 1 {
                    return Err(ShapeInferenceError::InvalidNode(
                        node.get_name().to_string(),
                        format!(
                            "Constant op must have one output, has {}",
                            node.get_output().len()
                        ),
                    ));
                }

                // Create an initializer
                let mut initializer = TensorProto::new();

                // Get constant value
                if let Ok(values) = node.get_attribute_value::<Vec<f32>>("value_floats", None) {
                    initializer.set_data_type(ScalarType::F32.to_datatype().value());
                    initializer.set_dims(vec![values.len() as i64]);
                    initializer.set_float_data(values);
                } else if let Ok(values) = node.get_attribute_value::<Vec<i64>>("value_ints", None)
                {
                    initializer.set_data_type(ScalarType::I64.to_datatype().value());
                    initializer.set_dims(vec![values.len() as i64]);
                    initializer.set_int64_data(values);
                } else if let Ok(values) = node.get_attribute_value::<i64>("value_int", None) {
                    initializer.set_int64_data(vec![values]);
                    initializer.set_data_type(ScalarType::I64.to_datatype().value());
                    initializer.set_dims(vec![1]);
                } else if let Ok(values) = node.get_attribute_value::<f32>("value_float", None) {
                    initializer.set_float_data(vec![values]);
                    initializer.set_data_type(ScalarType::F32.to_datatype().value());
                    initializer.set_dims(vec![1]);
                } else if let Ok(tp) = node.get_attribute_value::<TensorProto>("value", None) {
                    initializer = tp;
                    fix_raw_tensor(&mut initializer)?;
                } else {
                    log::debug!("Constant node attributes: {:?}", node.attribute);
                    return Err(ShapeInferenceError::Unsupported(
                        "Constant node with data types other than float, int".to_string(),
                    ));
                }

                log::info!(
                    "Replacing Constant node '{}' with an initializer (name='{}', shape={:?})",
                    node.get_name(),
                    node.output[0].clone(),
                    initializer.dims
                );

                initializer.set_name(node.output[0].clone()); // Needs to happen here because the name can be overwritten above when there is a tensor in the "value" attribute
                graph.initializer.push(initializer);
            }
            graph.node.remove(node_index);
        }
    }
    Ok(())
}

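/// Infer missing value shapes for all nodes in `graph`, optionally folding constant (sub)expressions
/// into initializers along the way. `opset_version` is used when folding requires evaluating ops.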
pub async fn infer_shapes(
    graph: &mut GraphProto,
    should_fold_constants: bool,
    opset_version: i64,
) -> Result<(), ShapeInferenceError> {
    let mut foldable_nodes: Vec<String> = vec![];
    let mut folded_node_indexes: Vec<usize> = vec![];

    if should_fold_constants {
        replace_constant_ops_with_initializers(graph)?;
    }

    let mut shapes = dimensions_infos(graph).map_err(ShapeInferenceError::UnsupportedDataType)?;

    // Needed for Reshape
    let mut initializers: HashMap<String, Cow<TensorProto>> = HashMap::from_iter(
        graph
            .initializer
            .iter()
            .map(|x| (x.get_name().to_string(), Cow::Borrowed(x))),
    );

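    // ONNX requires graph nodes to be listed in topological order, so by the time a node is visited
    // here, the shapes of its inputs have either been declared in the model or inferred earlier.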
    for (node_index, node) in graph.node.iter().enumerate() {
        log::debug!(
            "node: {} {} inputs {} -> outputs {}",
            node.get_op_type(),
            node.get_name(),
            node.get_input().join(", "),
            node.get_output().join(", ")
        );

        // Do shape inference if this node has at least one output for which the shape is not yet known
        if node
            .get_output()
            .iter()
            .any(|output_name| !shapes.contains_key(output_name.as_str()))
        {
            log::debug!("node needs shape inference: {}", node.get_name());

            let input_shapes: Vec<&Shape> = node
                .get_input()
                .iter()
                .map(|name| {
                    shapes
                        .get(name)
                        .ok_or_else(|| ShapeInferenceError::MissingInputShape(name.clone()))
                })
                .collect::<Result<_, ShapeInferenceError>>()?;

            let output_shapes = infer_output_shapes(node, &input_shapes, &initializers)?;

            // Check inferred shapes
            for (output_index, shape) in output_shapes.iter().enumerate() {
                if shape.rank() == 0 {
                    log::warn!(
                        "inferred shape for output {output_index} of node '{}' is empty: {shape}",
                        node.get_name()
                    );
                }
            }

            log::info!(
                "node {} inferred output shapes: {}",
                node.get_name(),
                output_shapes
                    .iter()
                    .enumerate()
                    .map(|(idx, x)| format!("{}={x}", node.output[idx]))
                    .collect::<Vec<String>>()
                    .join(", ")
            );

            if output_shapes.len() != node.get_output().len() {
                panic!("number of outputs inferred does not match node output count");
            }

            // Cache the inferred shapes and write to model
            for (output_idx, output_name) in node.get_output().iter().enumerate() {
                let output_shape = &output_shapes[output_idx];
                shapes.insert(output_name.clone(), output_shape.clone());
                let mut vip = ValueInfoProto::new();
                vip.set_name(output_name.clone());

                let mut tip = TypeProto::new();
                let mut ttp = TypeProto_Tensor::new();
                ttp.set_elem_type(output_shape.data_type.to_datatype().value());

                let mut tsp = TensorShapeProto::new();
                tsp.set_dim(
                    output_shape
                        .dims
                        .iter()
                        .map(|d| {
                            let mut tspd = TensorShapeProto_Dimension::new();
                            tspd.set_dim_value(*d as i64);
                            tspd
                        })
                        .collect(),
                );
                ttp.set_shape(tsp);
                tip.set_tensor_type(ttp);
                vip.set_field_type(tip);
                graph.value_info.push(vip);
            }

            // Can we fold the node altogether?
            let can_fold = should_fold_constants && {
                let all_inputs_are_constant = node
                    .input
                    .iter()
                    .all(|input_name| initializers.contains_key(input_name));
                let is_known_shape_node =
                    node.get_op_type() == "Shape" && shapes.contains_key(&node.input[0]);
                all_inputs_are_constant || is_known_shape_node
            };

            if can_fold {
                log::debug!("node '{}' can be folded", node.get_name());

                // Collect constant inputs
                let inputs: Vec<InputTensor> = node
                    .input
                    .iter()
                    .map(|input_name| {
                        if let Some(initializer) = initializers.get(input_name) {
                            InputTensor::try_from(initializer.as_ref())
                        } else {
                            // This should only happen when is_known_shape_node is true. In that case we will not do
                            // any GPU inference and the contents of this tensor don't matter
                            Ok(InputTensor::I64(Cow::Owned(vec![])))
                        }
                    })
                    .collect::<Result<_, _>>()
                    .map_err(|x| {
                        ShapeInferenceError::ConstantFoldingError(
                            ConstantFoldingError::UnsupportedDataType(x),
                        )
                    })?;

                if let Some(mut constant_output) = calculate_constant_node_outputs(
                    node,
                    &shapes,
                    &inputs,
                    &output_shapes,
                    &initializers,
                    opset_version,
                )
                .await
                .map_err(ShapeInferenceError::ConstantFoldingError)?
                {
                    // Save constant outputs as initializers
                    for (output_index, output_name) in node.output.iter().enumerate().rev() {
                        let output_tensor = constant_output.remove(output_index);

                        let output_shape = &output_shapes[output_index];
                        let mut initializer: TensorProto = TensorProto::from(
                            output_tensor,
                            output_shape.dims.iter().map(|x| *x as i64).collect(),
                        );
                        initializer.set_name(output_name.clone());
                        initializer.set_dims(output_shape.dims.iter().map(|x| *x as i64).collect());
                        initializers.insert(output_name.clone(), Cow::Owned(initializer));

                        assert_eq!(
                            &shapes[output_name], output_shape,
                            "output shape should be the same after folding"
                        );
                        folded_node_indexes.push(node_index);

                        log::info!(
                            "folded output '{output_name}' (#{output_index}) of node {} shape={output_shape}",
                            node.get_name(),
                        );
                    }
                } else {
                    foldable_nodes.push(node.get_name().to_string());
                }
            }
        }
    }

    // Remove folded nodes
    folded_node_indexes.sort();
    for index in folded_node_indexes.iter().rev() {
        graph.node.remove(*index);
    }

    // Save newly created initializers
    let new_initializers: Vec<TensorProto> = initializers
        .into_iter()
        .flat_map(|(_, x)| match x {
            Cow::Owned(z) => Some(z),
            Cow::Borrowed(_) => None,
        })
        .collect();

    for new_initializer in new_initializers {
        graph.initializer.push(new_initializer);
    }

    // Notify about missing fold implementations
    if !foldable_nodes.is_empty() {
        log::info!(
            "The following nodes can likely be folded, but currently aren't due to missing support: {}",
            foldable_nodes.join(", ")
        );
    }

    Ok(())
}

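/// Infer the output shapes of a single node from the shapes of its inputs. `initializers` is
/// consulted for ops (such as Reshape, Slice and ConstantOfShape) whose output shape depends on the
/// *values* of certain inputs, which therefore need to be statically known.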
pub(crate) fn infer_output_shapes(
    node: &NodeProto,
    input_shapes: &[&Shape],
    initializers: &HashMap<String, Cow<TensorProto>>,
) -> Result<Vec<Shape>, ShapeInferenceError> {
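    // Dispatch on the op type and on the number of inputs and outputs of the node; each arm returns
    // one inferred shape per declared output.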
    match (
        node.get_op_type(),
        input_shapes.len(),
        node.get_output().len(),
    ) {
        ("Clip", 1..=3, 1)
        | (
            "Identity" | "Sqrt" | "Relu" | "LeakyRelu" | "Abs" | "Acos" | "Acosh" | "Asin" | "Sin"
            | "Asinh" | "Atan" | "Atanh" | "Cos" | "Cosh" | "Elu" | "Erf" | "Exp" | "Log" | "Neg"
            | "Ceil" | "Floor" | "Reciprocal" | "Celu" | "Sign",
            1,
            1,
        ) => Ok(vec![input_shapes[0].clone()]),

        ("Cast", 1, 1) => {
            let to_value: i64 = node
                .get_attribute_value("to", None)
                .map_err(ShapeInferenceError::MissingAttribute)?;
            let to_data_type = ScalarType::from_i32(to_value as i32).map_err(|_| {
                ShapeInferenceError::InvalidNode(
                    node.get_name().to_string(),
                    format!(
                        "invalid value for to attribute ({}) for Cast operator",
                        to_value
                    ),
                )
            })?;

            let mut output_shape = input_shapes[0].clone();
            output_shape.data_type = to_data_type;

            Ok(vec![output_shape])
        }

        ("Flatten", 1, 1) => {
            let axis: usize = {
                let a = node.get_attribute_value("axis", Some(1)).unwrap();
                if a < 0 {
                    (a + input_shapes[0].rank() as i64) as usize
                } else {
                    a as usize
                }
            };
            if axis > input_shapes[0].rank() {
                return Err(ShapeInferenceError::InvalidNode(
                    node.get_name().to_string(),
                    format!("Flatten axis attribute ({axis}) should be less than or equal to rank of input ({})",input_shapes[0].rank()),
                ));
            }
            let input_dims = &input_shapes[0].dims;
            let outer_dim = if axis == 0 {
                1
            } else {
                input_dims[0..=(axis - 1)].iter().product::<u64>() as i64
            };
            let inner_dim = input_dims[axis..].iter().product::<u64>() as i64;

            let new_dims = vec![outer_dim, inner_dim];
            Ok(vec![Shape::from(input_shapes[0].data_type, &new_dims)])
        }

        ("GlobalAveragePool", 1, 1) => {
            let mut output_shape = input_shapes[0].clone();
            if output_shape.rank() < 2 {
                return Err(ShapeInferenceError::InvalidNode(
                    node.get_name().to_string(),
                    format!("invalid input rank for GlobalAveragePool: {output_shape}",),
                ));
            }
            for a in 2..output_shape.dims.len() {
                output_shape.dims[a] = 1;
            }
            Ok(vec![output_shape])
        }

        ("Gather", 2, 1) => {
            // https://github.com/onnx/onnx/blob/ceaeafa4cd2156c69dd9699bbdd2aa7d39e7c74c/onnx/defs/tensor/defs.cc#L1601
            let r = input_shapes[0].rank() as i64;
            if r < 1 {
                return Err(ShapeInferenceError::InvalidNode(
                    node.get_name().to_string(),
                    "data tensor must have rank 1 or greater".to_string(),
                ));
            }
            let q = input_shapes[1].rank() as i64;
            let mut axis = node
                .get_attribute_value("axis", Some(0))
                .map_err(ShapeInferenceError::MissingAttribute)?;
            if axis >= r || axis < -r {
                return Err(ShapeInferenceError::InvalidNode(
                    node.get_name().to_string(),
                    "axis must be less than data tensor rank".to_string(),
                ));
            }

            if axis < 0 {
                axis += r;
            }
            let out_rank = q + r - 1;
            Ok(vec![Shape::from(
                input_shapes[0].data_type,
                (0..out_rank)
                    .map(|idx| {
                        if idx < axis {
                            input_shapes[0].dim(idx as usize) as i64
                        } else if idx >= axis && idx < (axis + q) {
                            input_shapes[1].dim((idx - axis) as usize) as i64
                        } else {
                            input_shapes[0].dim((idx - q + 1) as usize) as i64
                        }
                    })
                    .collect::<Vec<i64>>()
                    .as_ref(),
            )])
        }

        ("Shape", 1, 1) => {
            let rank = input_shapes[0].rank() as i64;
            let mut start: i64 = node.get_attribute_value("start", Some(0)).unwrap();
            let mut end: i64 = node.get_attribute_value("end", Some(rank)).unwrap();
            if start < 0 {
                start += rank;
            }
            if end < 0 {
                end += rank;
            }

            // The output is a 1D int64 tensor containing dims[start..end], so its length is
            // (end - start) after clamping both to [0, rank]
            Ok(vec![Shape::from(
                ScalarType::I64,
                &[(end.clamp(0, rank) - start.clamp(0, rank)).max(0)],
            )])
        }

        ("Size", 1, 1) => Ok(vec![Shape::from(ScalarType::I64, &[1])]),

        ("Slice", num_inputs @ 3..=5, 1) => {
            let data_shape = input_shapes[0];

            // All negative values in `starts[i]` and `ends[i]` have `dims[axes[i]]` added to them,
            // where `dims` are the dimensions of `input`.
            let mut starts: Vec<i64> =
                static_initializer_value_i64(initializers, &node.get_input()[1])?
                    .iter()
                    .enumerate()
                    .map(|(idx, s)| {
                        if *s < 0 {
                            *s + data_shape.dim(idx) as i64
                        } else {
                            *s
                        }
                    })
                    .collect();
            if starts.is_empty() {
                log::warn!(
                    "starts not set for Slice, generating it... name={}",
                    node.get_input()[1]
                );
                starts = (0..data_shape.rank()).map(|_| 1).collect();
            }
            let mut ends: Vec<i64> =
                static_initializer_value_i64(initializers, &node.get_input()[2])?
                    .iter()
                    .enumerate()
                    .map(|(idx, s)| {
                        if *s < 0 {
                            *s + data_shape.dim(idx) as i64
                        } else {
                            *s
                        }
                    })
                    .collect();
            if ends.is_empty() {
                log::warn!("ends not set for Slice, generating it...");
                ends = data_shape.dims.iter().map(|x| *x as i64).collect();
            }

            // If `axes` are omitted, they are set to `[0, ..., r-1]`.
            let axes: Vec<i64> = if num_inputs > 3 {
                let x: Vec<i64> =
                    static_initializer_value_i64(initializers, &node.get_input()[3])?.into();
                if x.is_empty() {
                    (0..(data_shape.rank() as i64)).collect()
                } else {
                    x
                }
            } else {
                (0..(data_shape.rank() as i64)).collect()
            };

            // If `steps` are omitted, they are set to `[1, ..., 1]` of length `len(starts)`
            let steps: Vec<i64> = if num_inputs > 4 {
                static_initializer_value_i64(initializers, &node.get_input()[4])?.into()
            } else {
                log::debug!(
                    "steps not set for slice, generating it (data_shape rank={})",
                    data_shape.rank()
                );
                (0..(data_shape.rank() as i64)).map(|_| 1).collect()
            };

            if axes.len() != steps.len() {
                return Err(ShapeInferenceError::InvalidNode(node.get_name().to_string(), format!("length of axes attribute ({}) must be equal to length of steps attribute ({})", axes.len(), steps.len())));
            }

            // All negative elements of `axes` are made non-negative by adding `r` to them, where `r = rank(input)`.
            let axes: Vec<i64> = axes
                .into_iter()
                .map(|x| {
                    if x < 0 {
                        x + data_shape.rank() as i64
                    } else {
                        x
                    }
                })
                .collect();

            let mut output_shape: Vec<i64> =
                input_shapes[0].dims.iter().map(|x| *x as i64).collect();

            // https://github.com/onnx/onnx/blob/fb80e3ade84e9f406711aa41b9f3665753158371/onnx/defs/tensor/defs.cc#L969
            for (axis_index, axis) in axes.iter().enumerate() {
                let mut start = starts[axis_index];
                let mut end = ends[axis_index];
                let mut step = steps[axis_index];
                process_slice_inputs(
                    data_shape.dim(*axis as usize) as i64,
                    &mut start,
                    &mut end,
                    &mut step,
                )?;
                let temp = div_ceil(end - start, step).max(0);
                output_shape[*axis as usize] = temp;
            }

            Ok(vec![Shape::from(data_shape.data_type, &output_shape)])
        }

        (
            "ReduceMean" | "ReduceSum" | "ReduceMin" | "ReduceMax" | "ReduceSumSquare"
            | "ReduceLogSumExp" | "ReduceLogSum" | "ReduceL2" | "ReduceL1" | "ReduceProd",
            1,
            1,
        ) => {
            // https://github.com/onnx/onnx/blob/main/docs/Changelog.md#reducemean-18
            // Note: up to opset version 13 these ops take 'axes' as an attribute; from version 18 they take axes as an optional second input
            let noop_with_empty_axes = node
                .get_attribute_value("noop_with_empty_axes", Some(0))
                .map_err(ShapeInferenceError::MissingAttribute)?;

            let input_shape = input_shapes[0];
            let input_ndim = input_shape.rank();
            let all_axes: Vec<i64> = if noop_with_empty_axes == 0 {
                (0..(input_shape.dims.len() as i64)).collect()
            } else {
                vec![]
            };
            let axes: Vec<i64> = node
                .get_attribute_value("axes", Some(all_axes))
                .map_err(ShapeInferenceError::MissingAttribute)?
                .into_iter()
                .map(|idx| {
                    if idx < 0 {
                        (input_ndim as i64) + idx
                    } else {
                        idx
                    }
                })
                .collect();
            let keep_dims = node
                .get_attribute_value("keepdims", Some(1))
                .map_err(ShapeInferenceError::MissingAttribute)?;

            Ok(vec![Shape::from(
                input_shape.data_type,
                (0..input_ndim as i64)
                    .flat_map(|i| {
                        if !axes.contains(&i) {
                            vec![input_shape.dim(i as usize) as i64]
                        } else if keep_dims == 1 {
                            vec![1]
                        } else {
                            vec![]
                        }
                    })
                    .collect::<Vec<_>>()
                    .as_ref(),
            )])
        }

        ("Sub" | "Pow" | "Add" | "Div" | "Mul" | "Mod", 2, 1) => {
            if let Some(output_shape) =
                Shape::multi_broadcast(&[input_shapes[0].clone(), input_shapes[1].clone()])
            {
                Ok(vec![output_shape])
            } else {
                Err(ShapeInferenceError::InvalidNode(
                    node.get_name().to_string(),
                    format!(
                        "two inputs (left {} shape: {}, right {} shape: {}) must be broadcastable",
                        node.get_input()[0],
                        input_shapes[0],
                        node.get_input()[1],
                        input_shapes[1]
                    ),
                ))
            }
        }

        ("Conv", 2, num_outputs @ 1)
        | ("Conv", 3, num_outputs @ 1)
        | ("MaxPool", 1, num_outputs @ 1)
        | ("MaxPool", 1, num_outputs @ 2)
        | ("AveragePool", 1, num_outputs @ 1)
        | ("AveragePool", 1, num_outputs @ 2) => {
            // https://github.com/onnx/onnx/blob/ded7e3a27449750fb429b0f88a494e10fd555be7/onnx/defs/nn/old.cc#L240
            let use_dilation = true;
            let require_kernel_shape = matches!(node.get_op_type(), "MaxPool" | "AveragePool");
            let input_shape = input_shapes[0];
            if input_shape.rank() < 2 {
                return Err(ShapeInferenceError::InvalidNode(
                    node.get_name().to_string(),
                    "input shape must have at least two dimensions".to_string(),
                ));
            }

            let num_input_dims = input_shape.rank() - 2;

            // Obtain dilations info
            let dilations: Vec<i64> = if use_dilation && node.has_attribute("dilations") {
                let dilations_attr: Vec<i64> = node
                    .get_attribute_value("dilations", None)
                    .map_err(ShapeInferenceError::MissingAttribute)?;
                if dilations_attr.len() != num_input_dims {
                    return Err(ShapeInferenceError::InvalidNode(
                        node.get_name().to_string(),
                        "attribute dilations has incorrect size".to_string(),
                    ));
                }
                dilations_attr
            } else {
                (0..num_input_dims).map(|_| 1).collect()
            };

            // Obtain stride info
            let strides: Vec<i64> = if use_dilation && node.has_attribute("strides") {
                let strides_attr: Vec<i64> = node
                    .get_attribute_value("strides", None)
                    .map_err(ShapeInferenceError::MissingAttribute)?;
                if strides_attr.len() != num_input_dims {
                    return Err(ShapeInferenceError::InvalidNode(
                        node.get_name().to_string(),
                        "attribute strides has incorrect size".to_string(),
                    ));
                }
                strides_attr
            } else {
                (0..num_input_dims).map(|_| 1).collect()
            };

            // Obtain kernel shape
            let kernel_shape = if node.has_attribute("kernel_shape") {
                node.get_attribute_value::<Vec<i64>>("kernel_shape", None)
                    .map_err(ShapeInferenceError::MissingAttribute)?
            } else if require_kernel_shape {
                return Err(ShapeInferenceError::InvalidNode(
                    node.get_name().to_string(),
                    "node requires kernel_shape to be set".to_string(),
                ));
            } else {
                // Use second input shape to derive kernel shape
                input_shapes[1].dims[2..]
                    .iter()
                    .map(|x| *x as i64)
                    .collect()
            };

            if kernel_shape.len() != num_input_dims {
                return Err(ShapeInferenceError::InvalidNode(
                    node.get_name().to_string(),
                    "kernel_shape must have one element per spatial dimension of the input".to_string(),
                ));
            }

            // Determine effective kernel shape
            let effective_kernel_shape: Vec<i64> = kernel_shape
                .iter()
                .enumerate()
                .map(|(idx, dim)| (*dim - 1) * dilations[idx] + 1)
                .collect();

            // Obtain pads information
            let pads = if node.has_attribute("pads") {
                let p = node
                    .get_attribute_value::<Vec<i64>>("pads", None)
                    .map_err(ShapeInferenceError::MissingAttribute)?;
                if p.len() != num_input_dims * 2 {
                    return Err(ShapeInferenceError::InvalidNode(
                        node.get_name().to_string(),
                        "pads attribute has incorrect size".to_string(),
                    ));
                }
                p
            } else {
                let mut pads: Vec<i64> = (0..num_input_dims * 2).map(|_| 0).collect();
                let auto_pad = node
                    .get_attribute_value("auto_pad", Some(String::from("VALID")))
                    .unwrap();

                if auto_pad != "VALID" {
                    for i in 0..num_input_dims {
                        let mut residual: i64 = 0;
                        let stride = strides[i];

                        if stride > 1 {
                            residual = input_shape.dim(2 + i) as i64;
                            while residual >= stride {
                                residual -= stride;
                            }
                        }

                        let mut total_pad = if residual == 0 {
                            effective_kernel_shape[i] - stride
                        } else {
                            effective_kernel_shape[i] - residual
                        };
                        if total_pad < 0 {
                            total_pad = 0;
                        }

                        let half_pad_small = total_pad >> 1;
                        let half_pad_big = total_pad - half_pad_small;
                        if auto_pad == "SAME_UPPER" {
                            pads[i] = half_pad_small;
                            pads[i + num_input_dims] = half_pad_big;
                        } else if auto_pad == "SAME_LOWER" {
                            pads[i] = half_pad_big;
                            pads[i + num_input_dims] = half_pad_small;
                        }
                    }
                }
                pads
            };

            // Determine output shape
            let mut output_shape: Vec<i64> = vec![];
            output_shape.push(input_shape.dim(0) as i64);
            if require_kernel_shape {
                output_shape.push(input_shape.dim(1) as i64);
            } else {
                if input_shapes[1].rank() < 1 {
                    return Err(ShapeInferenceError::InvalidNode(
                        node.get_name().to_string(),
                        "second input has incorrect rank".to_string(),
                    ));
                }
                output_shape.push(input_shapes[1].dim(0) as i64);
            }

            let kernel_shape_size = kernel_shape.len();
            for i in 0..kernel_shape_size {
                // how big is the input, including padding
                let mut effective_input_size: i64 = input_shape.dim(2 + i) as i64;
                effective_input_size += pads[i];
                effective_input_size += pads[i + kernel_shape_size];

                // default is floor mode, i.e. ceil_mode is set to 0
                let ceil_mode = node.get_attribute_value("ceil_mode", Some(0)).unwrap();

                // how many times we can move the kernel from its initial position, based
                // on the stride
                let strided_kernel_positions = if ceil_mode == 1 {
                    div_ceil(effective_input_size - effective_kernel_shape[i], strides[i])
                } else {
                    (effective_input_size - effective_kernel_shape[i]) / strides[i]
                };

                output_shape.push(1 + strided_kernel_positions);
            }

            // MaxPool can have two outputs
            let final_output_shape = Shape::from(input_shape.data_type, &output_shape);
            Ok((0..num_outputs)
                .map(|_| final_output_shape.clone())
                .collect())
        }

        ("ConstantOfShape", 1, 1) => {
            let shape = static_initializer_value_i64(initializers, &node.get_input()[0])?;

            let value = node
                .get_attribute_value::<TensorProto>("value", None)
                .map_err(ShapeInferenceError::MissingAttribute)?;

            let data_type = ScalarType::from_i32(value.get_data_type())
                .map_err(ShapeInferenceError::UnsupportedDataType)?;

            Ok(vec![Shape::from(data_type, shape)])
        }

        ("Constant", 0, 1) => {
            if let Ok(values) = node.get_attribute_value::<Vec<f32>>("value_floats", None) {
                Ok(vec![Shape::from(ScalarType::F32, &[values.len() as i64])])
            } else if let Ok(values) = node.get_attribute_value::<Vec<i64>>("value_ints", None) {
                Ok(vec![Shape::from(ScalarType::I64, &[values.len() as i64])])
            } else if node.get_attribute_value::<f32>("value_float", None).is_ok() {
                Ok(vec![Shape::from(ScalarType::F32, &[1])])
            } else if node.get_attribute_value::<i64>("value_int", None).is_ok() {
                Ok(vec![Shape::from(ScalarType::I64, &[1])])
            } else if let Ok(tp) = node.get_attribute_value::<TensorProto>("value", None) {
                Ok(vec![Shape::from(
                    ScalarType::from_i32(tp.get_data_type()).map_err(|_| {
                        ShapeInferenceError::InvalidNode(
                            node.get_name().to_string(),
                            "invalid tensor data type".to_string(),
                        )
                    })?,
                    tp.get_dims(),
                )])
            } else {
                log::debug!("{:#?}", node);
                Err(ShapeInferenceError::Unsupported("Constant".to_string()))
            }
        }

        ("Reshape", 2, 1) => {
            let shape_tensor_name = &node.get_input()[1];

            if let Some(shape_tensor) = initializers.get(shape_tensor_name) {
                let allow_zero = node.get_attribute_value("allowzero", Some(0)).unwrap() == 1;

                // Get the tensor's contents
                let shape_tensor_contents = shape_tensor.get_int64_data();

                // The -1 value is allowed but not supported
                for dim in shape_tensor_contents {
                    match *dim {
                        -1 => {
                            return Err(ShapeInferenceError::Unsupported(
                                "Reshape with shape containing a -1 element".to_string(),
                            ))
                        }
                        i64::MIN..=-1 => {
                            return Err(ShapeInferenceError::InvalidNode(
                                node.get_name().to_string(),
                                format!("Reshape shape tensor cannot contain negative values except for -1 (contains {})", dim),
                            ))
                        }
                        0..=i64::MAX => (),
                    }
                }

                let output_shape: Vec<i64> = shape_tensor_contents
                    .iter()
                    .enumerate()
                    .map(|(idx, dim)| {
                        if *dim == 0 && !allow_zero {
                            input_shapes[0].dim(idx) as i64
                        } else {
                            *dim
                        }
                    })
                    .collect();

                if output_shape.iter().product::<i64>() != input_shapes[0].element_count() as i64 {
                    return Err(ShapeInferenceError::InvalidNode(
                        node.get_name().to_string(),
                        format!("Reshape input tensor (element count={}) must have the same number of elements as specified by the new shape ({})", input_shapes[0].element_count(), output_shape.iter().product::<i64>()),
                    ));
                }

                Ok(vec![Shape::from(input_shapes[0].data_type, &output_shape)])
            } else {
                Err(ShapeInferenceError::Unsupported(format!(
                    "Reshape with dynamic shape tensor (input name is {shape_tensor_name})"
                )))
            }
        }

        ("Concat", 1.., 1) => {
            let axis = node
                .get_attribute_value::<i64>("axis", None)
                .map_err(ShapeInferenceError::MissingAttribute)?;

            // All input shapes must be the same except for the dimension at the specified axis
            let mut shape: Vec<i64> = input_shapes[0].dims.iter().map(|x| *x as i64).collect();
            if axis < -(shape.len() as i64) || axis > (shape.len() - 1) as i64 {
                return Err(ShapeInferenceError::InvalidNode(
                    node.get_name().to_string(),
                    "axis attribute needs to be smaller than input tensor rank".to_string(),
                ));
            }

            let axis_index = if axis < 0 {
                ((shape.len() as i64) + axis) as usize
            } else {
                axis as usize
            };
            shape[axis_index] = input_shapes.iter().map(|s| s.dim(axis_index) as i64).sum();
            Ok(vec![Shape::from(input_shapes[0].data_type, &shape)])
        }

        ("Dropout", 1..=3, num_outputs @ 1..=2) => {
            let shape = input_shapes[0];
            Ok((0..num_outputs).map(|_| shape.clone()).collect())
        }

        ("Unsqueeze", num_inputs @ 1..=2, 1) => {
            let axes: Vec<i64> = if num_inputs == 2 {
                let shape_tensor_name = &node.get_input()[1];
                if let Some(shape_tensor) = initializers.get(shape_tensor_name) {
                    // Get the tensor's contents
                    shape_tensor.get_int64_data().to_vec()
                } else {
                    return Err(ShapeInferenceError::Unsupported(
                        "Unsqueeze with dynamic axis inputs".to_string(),
                    ));
                }
            } else {
                node.get_attribute_value("axes", None)
                    .map_err(ShapeInferenceError::MissingAttribute)?
            };

            let output_rank = input_shapes[0].rank() + axes.len();
            let mut input_shape: Vec<i64> =
                input_shapes[0].dims.iter().map(|x| *x as i64).collect();
            for i in axes {
                let index = if i < 0 {
                    ((output_rank as i64) + i) as usize
                } else {
                    i as usize
                };
                input_shape.insert(index, 1);
            }

            Ok(vec![Shape::from(input_shapes[0].data_type, &input_shape)])
        }

        ("Range", 3, 1) => {
            // Currently only int64 ranges are supported
            let start = static_initializer_value_i64(initializers, &node.input[0])?;
            let end = static_initializer_value_i64(initializers, &node.input[1])?;
            let step = static_initializer_value_i64(initializers, &node.input[2])?;

            if start.len() != 1 {
                return Err(ShapeInferenceError::InvalidNode(
                    node.get_name().to_string(),
                    format!(
                        "the start input needs to be a scalar, has {} elements",
                        start.len()
                    ),
                ));
            }

            if end.len() != 1 {
                return Err(ShapeInferenceError::InvalidNode(
                    node.get_name().to_string(),
                    format!(
                        "the end input needs to be a scalar, has {} elements",
                        end.len()
                    ),
                ));
            }

            if step.len() != 1 {
                return Err(ShapeInferenceError::InvalidNode(
                    node.get_name().to_string(),
                    format!(
                        "the step input needs to be a scalar, has {} elements",
                        step.len()
                    ),
                ));
            }

            // Per the ONNX spec the number of output elements is max(ceil((end - start) / step), 0)
            let element_count =
                (((end[0] - start[0]) as f64) / (step[0] as f64)).ceil().max(0.0) as i64;
            Ok(vec![Shape::from(ScalarType::I64, &[element_count])])
        }

        ("Squeeze", num_inputs @ 1..=2, 1) => {
            let has_axes = num_inputs == 2;
            let axes: Vec<i64> = if has_axes {
                let shape_tensor_name = &node.get_input()[1];
                if let Some(shape_tensor) = initializers.get(shape_tensor_name) {
                    // Get the tensor's contents
                    shape_tensor.get_int64_data().to_vec()
                } else {
                    return Err(ShapeInferenceError::Unsupported(
                        "Squeeze with dynamic axis inputs".to_string(),
                    ));
                }
            } else {
                vec![]
            };

            let output_shape: Vec<i64> = input_shapes[0]
                .dims
                .iter()
                .enumerate()
                .flat_map(|(idx, dim)| {
                    if (has_axes && axes.contains(&(idx as i64))) || (!has_axes && *dim == 1) {
                        vec![]
                    } else {
                        vec![*dim as i64]
                    }
                })
                .collect();

            Ok(vec![Shape::from(input_shapes[0].data_type, &output_shape)])
        }

        ("Transpose", 1, 1) => {
            let input_dims: Vec<i64> = input_shapes[0].dims.iter().map(|x| *x as i64).collect();
            let output_dims: Vec<i64> = match node.get_attribute_value::<Vec<i64>>("perm", None) {
                Ok(perm) => perm.iter().map(|idx| input_dims[*idx as usize]).collect(),
                Err(_) => input_dims.iter().rev().cloned().collect(),
            };
            Ok(vec![Shape::from(input_shapes[0].data_type, &output_dims)])
        }

        ("BatchNormalization", 1.., 1) => {
            // The first output's shape is equal to the input's shape
            Ok(vec![input_shapes[0].clone()])
        }

        (
            "ReduceMean" | "ReduceSum" | "ReduceMin" | "ReduceMax" | "ReduceSumSquare"
            | "ReduceLogSumExp" | "ReduceLogSum" | "ReduceL2" | "ReduceL1" | "ReduceProd",
            2,
            1,
        ) => Err(ShapeInferenceError::Unsupported(format!(
            "{} with two inputs (axes input not supported)",
            node.get_op_type()
        ))),

        (
            "Sub" | "Pow" | "Add" | "Div" | "Mul" | "Identity" | "Sqrt" | "ReduceMean" | "Gather"
            | "Constant" | "Relu" | "LeakyRelu" | "MaxPool" | "Conv" | "AveragePool" | "Reshape"
            | "Concat" | "Unsqueeze" | "Cast" | "Squeeze" | "Shape" | "Slice" | "Range"
            | "ConstantOfShape" | "Transpose" | "Abs" | "Acos" | "Acosh" | "Asin" | "Sin" | "Asinh"
            | "Atan" | "Atanh" | "Cos" | "Cosh" | "Elu" | "Erf" | "Exp" | "Log" | "Neg" | "Ceil"
            | "Reciprocal" | "Floor" | "Mod" | "Celu" | "ReduceSum" | "ReduceMin" | "ReduceMax"
            | "ReduceSumSquare" | "ReduceLogSumExp" | "ReduceLogSum" | "ReduceL2" | "ReduceL1"
            | "ReduceProd" | "Size" | "Sign",
            _,
            _,
        ) => Err(ShapeInferenceError::InvalidNode(
            node.get_name().to_string(),
            format!(
                "invalid number of inputs ({}) or outputs ({})",
                node.get_input().len(),
                node.get_output().len()
            ),
        )),

        (op_type, _inputs, _outputs) => {
            log::debug!("Shape inference unimplemented for op {op_type} with input shapes {input_shapes:#?}");
            Err(ShapeInferenceError::Unsupported(op_type.to_string()))
        }
    }
}

/// https://github.com/onnx/onnx/blob/fb80e3ade84e9f406711aa41b9f3665753158371/onnx/defs/tensor/defs.cc#L814
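/// Clamp and normalize the start/end/step values of a Slice along a single axis, following the ONNX
/// reference implementation linked above. Note that `input_rank` is the size of the sliced dimension
/// (the call site passes `data_shape.dim(axis)`), not the rank of the input tensor.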
fn process_slice_inputs(
    input_rank: i64,
    start: &mut i64,
    end: &mut i64,
    step: &mut i64,
) -> Result<(), ShapeInferenceError> {
    // process step
    if *step == 0 {
        return Err(ShapeInferenceError::InvalidNode(
            "".to_string(),
            "step value must not be zero for slice".to_string(),
        ));
    }
    // process start
    if *start < 0 {
        *start += input_rank;
    }
    if *step < 0 {
        *start = (*start).clamp(0, input_rank - 1);
    } else {
        *start = (*start).clamp(0, input_rank);
    }

    // process end
    if *end < 0 {
        *end += input_rank;
    }
    if *step < 0 {
        *end = (*end).clamp(-1, input_rank - 1);
    } else {
        *end = (*end).clamp(0, input_rank);
    }
    Ok(())
}

/// Some tensors only have the raw data field filled. This function moves that data to the respective fields (i.e. int64_data)
/// depending on the data type specified.
fn fix_raw_tensor(tensor: &mut TensorProto) -> Result<(), ShapeInferenceError> {
    if tensor.has_raw_data() {
        let raw_data = tensor.take_raw_data();
        match ScalarType::from_i32(tensor.get_data_type())
            .map_err(ShapeInferenceError::UnsupportedDataType)?
        {
            ScalarType::F32 => tensor.set_float_data(bytemuck::cast_slice(&raw_data[..]).to_vec()),
            ScalarType::I64 => tensor.set_int64_data(bytemuck::cast_slice(&raw_data[..]).to_vec()),
            ScalarType::I32 => tensor.set_int32_data(bytemuck::cast_slice(&raw_data[..]).to_vec()),
            ScalarType::U8 => tensor.set_raw_data(bytemuck::cast_slice(&raw_data[..]).to_vec()),
        }
    }
    Ok(())
}

#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use protobuf::Message;
    use wonnx::onnx::ModelProto;

    use crate::shape_inference::infer_shapes;

    use super::dimensions_infos;

    /// Load a model, strip (and stash) all shape info for intermediate values, then re-infer shapes and compare with stashed original
    async fn test_shape_inference_for_model(path: &str, should_fold_constants: bool) {
        let mut model =
            ModelProto::parse_from_bytes(&std::fs::read(path).expect("ONNX Model path not found."))
                .unwrap();

        let graph = model.mut_graph();
        let infos = dimensions_infos(graph).unwrap();
        graph.value_info.clear();
        infer_shapes(graph, should_fold_constants, 13)
            .await
            .unwrap();
        let new_infos = dimensions_infos(graph).unwrap();

        let keys_in_old: HashSet<String> = infos.keys().cloned().collect();
        let keys_in_new: HashSet<String> = new_infos.keys().cloned().collect();
        let all_keys: HashSet<String> = keys_in_old.union(&keys_in_new).cloned().collect();

        for key in all_keys {
            if !keys_in_old.contains(&key) || !keys_in_new.contains(&key) || infos[&key].is_empty()
            {
                // Key is new after shape inference (the original model was apparently missing shapes)
                // Key missing after shape inference (this may be the result of constant folding)
                // Empty dims in source means missing... ignore
            } else {
                assert_eq!(
                    infos[&key], new_infos[&key],
                    "different shape inferred for {key}"
                )
            }
        }
    }

    #[test]
    fn test_shape_inference() {
        let _ = env_logger::builder().is_test(true).try_init();

        pollster::block_on(async {
            test_shape_inference_for_model("../data/models/opt-mnist.onnx", false).await;
            test_shape_inference_for_model("../data/models/opt-squeeze.onnx", false).await;
            test_shape_inference_for_model("../data/models/single_relu.onnx", false).await;
            test_shape_inference_for_model("../data/models/mobilenetv2-7.onnx", true).await;
        });
    }
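
    // Sanity checks for the ceiling-division helper used by the pooling and Slice shape
    // computations above.
    #[test]
    fn test_div_ceil() {
        assert_eq!(super::div_ceil(6, 3), 2);
        assert_eq!(super::div_ceil(7, 3), 3);
        assert_eq!(super::div_ceil(1, 3), 1);
        assert_eq!(super::div_ceil(0, 3), 0);
    }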
}