wonnx/
optimizer.rs

//! Optimizer that walks the DAG and transforms or coalesces ops for quicker execution
use crate::{
    gpu::GpuModel,
    ir::{Input, Node, NodeDefinition, NodeIdentifier, OperatorDefinition},
    onnx::{NodeProto, TensorProto},
    resource::{padding, request_device_queue},
    utils::{
        attribute, AttributeNotFoundError, DataTypeError, NodeAttributes, OutputTensor, ScalarType,
        Shape,
    },
    GpuError,
};
use async_recursion::async_recursion;
use bytemuck::pod_collect_to_vec;
use protobuf::RepeatedField;
use std::{
    borrow::Cow,
    collections::{HashMap, VecDeque},
    sync::Arc,
};
use thiserror::Error;

#[derive(Debug, Error)]
pub enum OptimizerError {
    #[error("node has no inputs")]
    NoInputs,

    #[error("unsupported: {0}")]
    Unsupported(String),

    #[error("invalid data type {data_type:?} for input {input} of op {op}")]
    InvalidInputDataType {
        data_type: ScalarType,
        input: String,
        op: String,
    },

    #[error("error with data type: {0}")]
    InvalidDataType(#[from] DataTypeError),

    #[error("node is invalid: {0}")]
    InvalidNode(String),

    #[error("required attribute not found: {0}")]
    AttributeNotFound(#[from] AttributeNotFoundError),

    #[error("error during constant folding: {0}")]
    ConstantFoldingError(#[from] GpuError),
}

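/// Walks the DAG from the output node towards the inputs, memoizing optimized
/// nodes along the way. A minimal usage sketch (assuming `root` was obtained
/// from `ir::Node::from_model`, as in the tests below):
///
/// ```ignore
/// let mut opt = Optimizer::new(13);
/// let optimized_root = opt.optimize(root).await?;
/// ```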
pub struct Optimizer<'model> {
    padded_tensors: HashMap<String, Arc<Node<'model>>>,
    optimized: HashMap<NodeIdentifier<'model>, Arc<Node<'model>>>,
    onnx_opset_version: i64,
}

impl<'model> Optimizer<'model> {
    pub fn new(onnx_opset_version: i64) -> Self {
        Self {
            padded_tensors: HashMap::new(),
            optimized: HashMap::new(),
            onnx_opset_version,
        }
    }

    // Calculates the output of a constant node, then returns a node that contains the result as an initializer
    async fn fold_constant_node(
        &self,
        node: Arc<Node<'model>>,
    ) -> Result<Option<Arc<Node<'model>>>, OptimizerError> {
        assert!(node.is_constant());

        match node.definition() {
            NodeDefinition::Operator(op_def) => {
                // TODO: constant nodes with multiple outputs
                if op_def.proto.output.len() != 1 {
                    log::warn!(
                        "node {:?} is constant, but has multiple outputs, which we can't fold yet",
                        node.definition()
                    );
                    return Ok(None);
                }

                match op_def.proto.get_op_type() {
                    "Constant" => Ok(Some(Arc::new(Node {
                        definition: NodeDefinition::Tensor(Box::new(Cow::Owned(
                            Self::constant_node_to_tensor(node)?,
                        ))),
                        inputs: vec![],
                    }))),
                    _ => self.infer_constant_node_to_tensor(node.clone()).await,
                }
            }
            NodeDefinition::Tensor(_) => Ok(None), // already an initializer
            NodeDefinition::Input(_) | NodeDefinition::Missing => unreachable!(),
            NodeDefinition::Outputs { .. } => Ok(None), // the outputs themselves are already constant, so nothing to do
        }
    }

    // Takes a node with operator type 'Shape' and returns its output as a tensor
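    // ONNX Shape accepts optional `start` and `end` attributes that slice the
    // shape vector; negative values count from the back. For an input of shape
    // [3, 4, 5], for instance, start=-1 yields [5] and end=1 yields [3] (see
    // `test_shape_operator` below).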
    fn shape_node_to_tensor(node: Arc<Node<'model>>) -> Result<TensorProto, OptimizerError> {
        let NodeDefinition::Operator(op_def) = node.definition() else {
            panic!("node must be a Shape node");
        };
        assert_eq!(op_def.proto.get_op_type(), "Shape");

        if node.inputs.len() != 1 {
            return Err(OptimizerError::InvalidNode(format!(
                "Shape node should only have one input, has {}",
                node.inputs.len()
            )));
        }

        // Determine the shape of the input
        let input = &node.inputs[0];
        let in_node = &input.source_node.definition;
        let in_shape = match in_node {
            NodeDefinition::Input(input) => input.get_shape()?,
            NodeDefinition::Operator(input_op_def) => {
                input_op_def.output_shapes[input.output_index].clone()
            }
            NodeDefinition::Tensor(input_tensor) => Shape::from(
                ScalarType::from_i32(input_tensor.get_data_type())
                    .map_err(OptimizerError::InvalidDataType)?,
                input_tensor.get_dims(),
            ),
            NodeDefinition::Outputs { .. } => {
                return Err(OptimizerError::Unsupported(
                    "output node cannot be used as an input to Shape node".to_string(),
                ))
            }
            NodeDefinition::Missing => {
                return Err(OptimizerError::InvalidNode(
                    "Shape node has missing input".to_string(),
                ))
            }
        };
        let rank = in_shape.rank() as i64;
        let mut start: i64 = op_def.proto.get_attribute_value("start", Some(0)).unwrap();
        let mut end: i64 = op_def.proto.get_attribute_value("end", Some(rank)).unwrap();
        if start < 0 {
            start += rank;
        }
        if end < 0 {
            end += rank;
        }
        start = start.clamp(0, rank);
        end = end.clamp(0, rank);

        // After clamping, start and end are guaranteed to lie in [0, rank]; only
        // an inverted range remains to be checked
        if end < start {
            return Err(OptimizerError::InvalidNode(format!(
                "end index ({end}) of Shape node cannot be below its start index ({start})"
            )));
        }

        let values: Vec<i64> = in_shape.dims[(start as usize)..(end as usize)]
            .iter()
            .map(|x| *x as i64)
            .collect();
        let dims = vec![values.len() as i64];
        Ok(TensorProto::from(OutputTensor::I64(values), dims))
    }

    // Takes a node with operator type 'Constant' and returns its output as a tensor
    fn constant_node_to_tensor(node: Arc<Node<'model>>) -> Result<TensorProto, OptimizerError> {
        let NodeDefinition::Operator(op_def) = node.definition() else {
            panic!("node must be a Constant node");
        };
        assert_eq!(op_def.proto.get_op_type(), "Constant");
        let proto = &op_def.proto;
        let output_name = proto.output.get(0).unwrap().to_owned();

        let mut tp: TensorProto =
            if let Ok(values) = proto.get_attribute_value::<Vec<f32>>("value_floats", None) {
                let dims = vec![values.len() as i64];
                TensorProto::from(OutputTensor::F32(values), dims)
            } else if let Ok(values) = proto.get_attribute_value::<Vec<i64>>("value_ints", None) {
                let dims = vec![values.len() as i64];
                TensorProto::from(OutputTensor::I64(values), dims)
            } else if let Ok(value) = proto.get_attribute_value::<f32>("value_float", None) {
                TensorProto::from(OutputTensor::F32(vec![value]), vec![1])
            } else if let Ok(value) = proto.get_attribute_value::<i64>("value_int", None) {
                TensorProto::from(OutputTensor::I64(vec![value]), vec![1])
            } else if let Ok(tp) = proto.get_attribute_value::<TensorProto>("value", None) {
                tp
            } else {
                return Err(OptimizerError::Unsupported(
                    "Constant node with unknown value type".to_string(),
                ));
            };

        tp.set_name(output_name);
        Ok(tp)
    }

    // Takes a node with operator type 'Size' and returns its output as a tensor
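    // ('size' being the total number of elements, i.e. the product of all
    // dimensions: an input of shape [1, 2, 3] has size 6)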
    fn size_node_to_tensor(node: Arc<Node<'model>>) -> Result<TensorProto, OptimizerError> {
        let NodeDefinition::Operator(op_def) = node.definition() else {
            panic!("node must be a Size node");
        };
        assert_eq!(op_def.proto.get_op_type(), "Size");

        if node.inputs.len() != 1 {
            return Err(OptimizerError::InvalidNode(format!(
                "Size node should only have one input, has {}",
                node.inputs.len()
            )));
        }

        // Determine the element count of the input
        let input = &node.inputs[0];
        let in_node = &input.source_node.definition;
        let in_element_count: i64 = match in_node {
            NodeDefinition::Input(input) => input.get_shape()?.element_count() as i64,
            NodeDefinition::Operator(input_op_def) => {
                input_op_def.output_shapes[input.output_index].element_count() as i64
            }
            NodeDefinition::Tensor(input_tensor) => input_tensor.get_dims().iter().product(),
            NodeDefinition::Outputs { .. } => {
                return Err(OptimizerError::Unsupported(
                    "output node cannot be used as an input to Size node".to_string(),
                ))
            }
            NodeDefinition::Missing => {
                return Err(OptimizerError::InvalidNode(
                    "Size node has missing input".to_string(),
                ))
            }
        };

        Ok(TensorProto::from(
            OutputTensor::I64(vec![in_element_count]),
            vec![1],
        ))
    }

    // Infers the output for a constant node (must be a constant operator node, or the function panics)
    async fn infer_constant_node_to_tensor(
        &self,
        node: Arc<Node<'model>>,
    ) -> Result<Option<Arc<Node<'model>>>, OptimizerError> {
        assert!(node.is_constant());

        // Create an output node so we can perform inference for this node
        if let NodeDefinition::Operator(op_def) = node.definition() {
            let output_name = op_def.proto.output.get(0).unwrap().to_owned();

            let out_node = Arc::new(Node {
                definition: NodeDefinition::Outputs {
                    names: vec!["output".to_string()],
                },
                inputs: vec![Input {
                    source_node: node.clone(),
                    output_index: 0,
                }],
            });

            // Perform inference (with an empty input map; the node is constant)
            let (device, queue) = request_device_queue().await;
            let gm = GpuModel::from(out_node, device, queue, self.onnx_opset_version)
                .map_err(OptimizerError::ConstantFoldingError)?;
            let mut outputs = gm.infer(&HashMap::new()).await?;

            // Take the single output tensor and make it into an initializer node
            let (_, output_tensor) = outputs.drain().next().unwrap();
            log::info!("folded {output_name} to {output_tensor:?}");
            let mut output_tensor_proto = TensorProto::from(
                output_tensor,
                op_def.output_shapes[0]
                    .dims
                    .iter()
                    .map(|x| *x as i64)
                    .collect(),
            );
            output_tensor_proto.set_name(output_name);

            let tensor_node = Node {
                definition: NodeDefinition::Tensor(Box::new(Cow::Owned(output_tensor_proto))),
                inputs: vec![],
            };

            Ok(Some(Arc::new(tensor_node)))
        } else {
            panic!("node to fold must be an operator")
        }
    }

    /// Optimize a branch of a graph (memoized)
    #[async_recursion]
    pub async fn optimize(
        &mut self,
        node: Arc<Node<'model>>,
    ) -> Result<Arc<Node<'model>>, OptimizerError> {
        let identifier = node.identifier();
        match self.optimized.get(&identifier) {
            Some(opt_node) => Ok(opt_node.clone()),
            None => {
                let opt_node = self.optimize_actual(node).await?;
                self.optimized.insert(identifier, opt_node.clone());
                Ok(opt_node)
            }
        }
    }

    /// Optimize a branch of a graph. Takes a node and attempts to form a chain of nodes with a single (dynamic) input by
    /// traversing towards the inputs.
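    /// For example, the subgraph `X -> Neg -> Neg -> output` forms the chain
    /// `[Neg, Neg]`, which `optimize_chain` below can then collapse.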
    #[async_recursion]
    async fn optimize_actual(
        &mut self,
        node: Arc<Node<'model>>,
    ) -> Result<Arc<Node<'model>>, OptimizerError> {
        // Try to form a chain of nodes that have one dynamic input
        let prior;
        let mut chain = VecDeque::new();
        chain.push_back(node.clone());

        loop {
            let head = chain.front().unwrap();
            let dynamic_inputs = head
                .inputs
                .iter()
                .filter(|input| input.source_node.is_dynamic() && input.output_index == 0)
                .collect::<Vec<&Input>>();

            if dynamic_inputs.len() != 1 {
                prior = chain.pop_front().unwrap();
                break;
            }
            chain.push_front(dynamic_inputs[0].source_node.clone());
        }

        log::debug!(
            "optimize: node={:?} def={:?} chain={}, next={:?}",
            node.identifier(),
            node.definition,
            chain
                .iter()
                .map(|x| format!("[{:?}]", x.definition))
                .collect::<Vec<String>>()
                .join(" -> "),
            prior.identifier()
        );

        // Try to simplify this chain of nodes
        if chain.len() > 1 {
            let mut final_chain: Vec<Arc<Node>> = vec![];
            while !chain.is_empty() {
                log::debug!("optimize chain {}", chain.len());
                while self.optimize_chain(&mut chain)? {
                    log::debug!("optimize chain succeeded {}", chain.len());
                }

                if !chain.is_empty() {
                    // Now pop off the first item and make it final
                    let first = chain.pop_front().unwrap();
                    final_chain.push(first);
                }

                log::debug!(
                    "optimized chain: {}",
                    final_chain
                        .iter()
                        .map(|x| format!("[{:?}]", x.definition))
                        .collect::<Vec<String>>()
                        .join(" -> ")
                );
            }
            drop(chain);

            // Optimize the node that feeds the chain
            let optimized_next = self.optimize(prior).await?;

            if final_chain.is_empty() {
                return Ok(optimized_next);
            }

            // Fix up the connections between these nodes
            for node_index in 0..final_chain.len() {
                let consumer = final_chain[node_index].clone();
                let producer = if node_index == 0 {
                    optimized_next.clone()
                } else {
                    final_chain[node_index - 1].clone()
                };
                final_chain[node_index] = self
                    .locally_optimized_node_with(
                        consumer.clone(),
                        consumer
                            .inputs
                            .iter()
                            .map(|old_input| {
                                // Each node is guaranteed to have only one 'dynamic' input. This is the one we will replace
                                let is_dynamic_source = old_input.source_node.is_dynamic()
                                    && old_input.output_index == 0;
                                if is_dynamic_source {
                                    Input {
                                        source_node: producer.clone(),
                                        output_index: 0,
                                    }
                                } else {
                                    old_input.clone()
                                }
                            })
                            .collect(),
                    )
                    .await?;
            }

            Ok(final_chain.last().unwrap().clone())
        } else {
            // Just optimize this node's inputs recursively
            let mut new_inputs = Vec::with_capacity(node.inputs.len());
            for input in node.inputs.iter() {
                new_inputs.push(Input {
                    source_node: self.optimize(input.source_node.clone()).await?,
                    output_index: input.output_index,
                });
            }
            self.locally_optimized_node_with(node.clone(), new_inputs)
                .await
        }
    }

    /// Create a new node from an existing definition, applying optimizations local to a single node
    async fn locally_optimized_node_with(
        &mut self,
        node: Arc<Node<'model>>,
        mut new_inputs: Vec<Input<'model>>,
    ) -> Result<Arc<Node<'model>>, OptimizerError> {
        log::debug!(
            "locally_optimized_node_with {:?} {:?}",
            node.identifier(),
            node.definition()
        );

        // Fold Shape/Size nodes (these are not considered constant, but can still be folded)
        if let NodeDefinition::Operator(op_def) = &node.definition {
            match op_def.proto.get_op_type() {
                "Shape" => {
                    return Ok(Arc::new(Node {
                        definition: NodeDefinition::Tensor(Box::new(Cow::Owned(
                            Self::shape_node_to_tensor(node)?,
                        ))),
                        inputs: vec![],
                    }))
                }
                "Size" => {
                    return Ok(Arc::new(Node {
                        definition: NodeDefinition::Tensor(Box::new(Cow::Owned(
                            Self::size_node_to_tensor(node)?,
                        ))),
                        inputs: vec![],
                    }))
                }
                _ => {}
            }
        }

        // Fold constant nodes
        if node.is_constant() && !matches!(node.definition, NodeDefinition::Missing) {
            log::debug!(
                "node is constant: {:?} {:?}",
                node.identifier(),
                node.definition()
            );
            if let Some(const_node) = self.fold_constant_node(node.clone()).await? {
                return Ok(const_node);
            }
        }

        match &node.definition {
            NodeDefinition::Operator(op_def) => {
                match op_def.proto.get_op_type() {
                    "Conv" | "ConvRelu" | "ConvLeakyRelu" => {
                        // This optimization pads the weight tensor of 3x3 convolutions, because a mat3x3 has a stride of
                        // 16 bytes in WGSL. The padding lets the shader treat the weights as matrices, which improves performance.
                        if new_inputs.len() > 2
                            && op_def
                                .proto
                                .get_attribute_value::<Vec<i64>>("kernel_shape", None)?
                                == [3, 3]
                            && (op_def
                                .proto
                                .get_attribute_value("pads", Some(vec![0, 0, 0, 0]))?
                                == [1, 1, 1, 1]
                                || op_def.proto.get_attribute_value(
                                    "auto_pad",
                                    Some("SAME_UPPER".to_string()),
                                )? == "SAME_UPPER")
                            && op_def
                                .proto
                                .get_attribute_value("strides", Some(vec![1, 1]))?
                                == [1, 1]
                            && op_def.proto.get_attribute_value("group", Some(1))? == 1
                            && op_def.output_shapes[0].dim(1) % 4 == 0
                        {
                            if let NodeDefinition::Tensor(tensor) =
                                &new_inputs[1].source_node.definition
                            {
                                new_inputs[1] = Input {
                                    output_index: 0,
                                    source_node: match self.padded_tensors.get(tensor.get_name()) {
                                        Some(padded_tensor_node) => padded_tensor_node.clone(),
                                        None => {
                                            let data = tensor.get_float_data();
                                            let raw_data = if !data.is_empty() {
                                                bytemuck::cast_slice(data)
                                            } else {
                                                tensor.get_raw_data()
                                            };

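                                            // `padding(data, 12, 4)` is expected to append 4 bytes
                                            // of padding after every 12-byte chunk, so that each
                                            // vec3<f32> row of the 3x3 kernel occupies the 16-byte
                                            // mat3x3 stride mentioned above.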
                                            let padded_raw_data = padding(raw_data, 12, 4);

                                            log::info!(
                                                "applying padding optimization to tensor {}: strides data is {} bytes before, {} bytes after",
                                                tensor.get_name(),
                                                raw_data.len(),
                                                padded_raw_data.len()
                                            );

                                            // Create a new tensor with the padded data
                                            let mut new_tensor = tensor.clone().into_owned();
                                            new_tensor.set_float_data(vec![]);
                                            new_tensor.set_raw_data(padded_raw_data);
                                            let new_node = Arc::new(Node {
                                                definition: NodeDefinition::Tensor(Box::new(
                                                    Cow::Owned(new_tensor),
                                                )),
                                                inputs: vec![],
                                            });
                                            self.padded_tensors.insert(
                                                tensor.get_name().to_string(),
                                                new_node.clone(),
                                            );
                                            new_node
                                        }
                                    },
                                }
                            }
                        }

                        let new_node = Node {
                            inputs: new_inputs,
                            definition: NodeDefinition::Operator(op_def.clone()),
                        };

                        Ok(Arc::new(new_node))
                    }

                    // The Clip, Pad, Split, Resize, Reshape and Reduce* operators each take optional inputs that influence
                    // the operation. These are typically statically initialized tensors containing shapes. For more
                    // efficient execution we move these static values to attributes.
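                    // For example, a Reshape whose `shape` input is a static initializer [2, 3] is rewritten to a
                    // Reshape with attribute shape=[2, 3] and only the data input remaining.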
                    op @ ("Clip" | "Pad" | "Split" | "Resize" | "Reshape" | "ReduceMean"
                    | "ReduceSum" | "ReduceMin" | "ReduceMax" | "ReduceSumSquare"
                    | "ReduceLogSumExp" | "ReduceLogSum" | "ReduceL2" | "ReduceL1"
                    | "ReduceProd") => {
                        if new_inputs.is_empty() {
                            return Err(OptimizerError::NoInputs);
                        }

                        // Names of the inputs (see ONNX operator spec)
                        let attr_names = match op {
                            "Split" => SPLIT_INPUT_NAMES,
                            "Resize" => RESIZE_INPUT_NAMES,
                            "Reshape" => RESHAPE_INPUT_NAMES,
                            "Clip" => CLIP_INPUT_NAMES,
                            "Pad" => PAD_INPUT_NAMES,
                            "ReduceSum" | "ReduceL1" | "ReduceL2" | "ReduceLogSum"
                            | "ReduceLogSumExp" | "ReduceMax" | "ReduceMean" | "ReduceMin"
                            | "ReduceProd" | "ReduceSumSquare" => REDUCE_OPS_INPUT_NAMES,
                            _ => unreachable!(),
                        };

                        // Make a new copy of the attributes list (we're going to add attributes)
                        let mut new_proto = op_def.proto.clone().into_owned();
                        let mut attributes = op_def.proto.get_attribute().to_vec();

                        // Loop over the inputs (skipping the first one - that's going to be the data input)
                        for input_index in 1..(new_inputs.len().min(attr_names.len())) {
                            let source_node = &new_inputs[input_index].source_node;
                            match &source_node.definition {
                                // If the input is an initializer (Tensor) we can obtain the data from the definition and move it to an attribute
                                NodeDefinition::Tensor(tensor_proto) => {
                                    let attr_name = attr_names[input_index];
                                    let data_type =
                                        ScalarType::from_i32(tensor_proto.get_data_type())?;

                                    match (op, attr_name) {
                                        ("Split", "split")
                                        | ("Resize", "roi")
                                        | ("Resize", "sizes")
                                        | ("Reshape", "shape")
                                        | (
                                            "ReduceMean" | "ReduceSum" | "ReduceMin" | "ReduceMax"
                                            | "ReduceSumSquare" | "ReduceLogSumExp"
                                            | "ReduceLogSum" | "ReduceL2" | "ReduceL1"
                                            | "ReduceProd",
                                            "axes",
                                        )
                                        | ("Pad", "pads")
                                        | ("Resize", "scales")
                                        | ("Clip", "min" | "max") => match data_type {
                                            ScalarType::F32 => {
                                                let value: Vec<f32> = if tensor_proto
                                                    .get_float_data()
                                                    .is_empty()
                                                {
                                                    pod_collect_to_vec(tensor_proto.get_raw_data())
                                                } else {
                                                    tensor_proto.get_float_data().to_vec()
                                                };
                                                log::info!(
                                                    "transferring input {} for op {} to f32 attribute (initializer data type: {:?}): {:?}",
                                                    attr_name,
                                                    op,
                                                    data_type,
                                                    value,
                                                );
                                                attributes.push(attribute(
                                                    attr_names[input_index],
                                                    value,
                                                ));
                                            }
                                            ScalarType::I64 => {
                                                let value = if tensor_proto
                                                    .get_int64_data()
                                                    .is_empty()
                                                {
                                                    pod_collect_to_vec(tensor_proto.get_raw_data())
                                                } else {
                                                    tensor_proto.get_int64_data().to_vec()
                                                };
                                                log::info!(
                                                    "transferring input {} for op {} to i64 attribute (initializer data type: {:?}): {:?}",
                                                    attr_name,
                                                    op,
                                                    data_type,
                                                    value,
                                                );
                                                attributes.push(attribute(
                                                    attr_names[input_index],
                                                    value,
                                                ));
                                            }
                                            _ => {
                                                return Err(OptimizerError::InvalidInputDataType {
                                                    data_type,
                                                    input: attr_name.to_string(),
                                                    op: op.to_string(),
                                                })
                                            }
                                        },
                                        _ => {
                                            // Some other unspecified input that we do not support yet
                                            return Err(OptimizerError::Unsupported(format!(
                                                "data_type {} for input {} to op {}",
                                                tensor_proto.get_data_type(),
                                                attr_name,
                                                op
                                            )));
                                        }
                                    }
                                }
                                NodeDefinition::Missing => {
                                    // Missing (optional) input; simply drop it
                                }
                                _ => {
                                    // One of the inputs (except the first) is something other than a tensor (e.g. 'dynamic')
                                    return Err(OptimizerError::Unsupported(format!(
                                        "{} operation with dynamic input for {}",
                                        op, attr_names[input_index]
                                    )));
                                }
                            }
                        }

                        // Create a new node with the extra attributes
                        new_proto.set_attribute(RepeatedField::from(attributes));

                        let new_node = Node {
                            inputs: vec![new_inputs[0].clone()],
                            definition: NodeDefinition::Operator(Box::new(OperatorDefinition {
                                proto: Cow::Owned(new_proto),
                                output_shapes: op_def.output_shapes.clone(),
                            })),
                        };

                        Ok(Arc::new(new_node))
                    }

                    _ => Ok(Arc::new(Node {
                        inputs: new_inputs,
                        definition: NodeDefinition::Operator(op_def.clone()),
                    })),
                }
            }
            NodeDefinition::Tensor(..) | NodeDefinition::Input(..) => {
                assert!(
                    new_inputs.is_empty(),
                    "non-operator node cannot have inputs"
                );
                // No need to do anything with the provided new inputs
                Ok(node.clone())
            }
            NodeDefinition::Outputs { .. } => Ok(Arc::new(Node {
                inputs: new_inputs,
                definition: node.definition().clone(),
            })),
            NodeDefinition::Missing => Ok(node.clone()),
        }
    }

    /// Attempt to fuse several operators in a chain of operators with no other dynamic inputs. The function receives a list
    /// of nodes that are guaranteed to be operators that each have exactly one dynamic input. It is free to remove or add
    /// nodes to this list. The caller will fix up the input/output relationships between the nodes.
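    /// Returns Ok(true) when the chain was changed (the caller keeps calling until no more rewrites apply), and Ok(false)
    /// otherwise.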
    fn optimize_chain(
        &mut self,
        chain: &mut VecDeque<Arc<Node<'model>>>,
    ) -> Result<bool, OptimizerError> {
        // Start by throwing out all Identity nodes
        chain.retain(|n| match &n.definition {
            NodeDefinition::Operator(op_def) => op_def.proto.get_op_type() != "Identity",
            _ => true,
        });

        let names: Vec<&str> = chain
            .iter()
            .map(|x| match &x.definition {
                NodeDefinition::Operator(op_def) => op_def.proto.get_op_type(),
                _ => "",
            })
            .collect();

        log::debug!("optimize_chain {:?}", names);

        match &names[..] {
            // Double Neg: the two negations cancel out, so just cull both
            ["Neg", "Neg", ..] => {
                chain.pop_front();
                chain.pop_front();
                Ok(true)
            }

            // Conv+Relu or Conv+LeakyRelu: combine into ConvRelu/ConvLeakyRelu
            ["Conv", "Relu", ..] | ["Conv", "LeakyRelu", ..] => {
                let conv = chain[0].clone();
                let relu = chain[1].clone();

                if let (NodeDefinition::Operator(conv_def), NodeDefinition::Operator(relu_def)) =
                    (&conv.definition, &relu.definition)
                {
                    // Use the Conv node as template for the new fused Conv[Leaky]Relu node
                    let mut convrelu_def = *conv_def.clone();
                    let mut convrelu_proto = conv_def.proto.clone().into_owned();
                    let new_op_type = match relu_def.proto.get_op_type() {
                        "LeakyRelu" => "ConvLeakyRelu",
                        "Relu" => "ConvRelu",
                        _ => unreachable!(),
                    };
                    convrelu_proto.set_op_type(new_op_type.to_string());

                    // Copy all Relu attributes over to the copy of the Conv node
                    let mut attributes = conv_def.proto.get_attribute().to_vec();
                    attributes.extend(relu_def.proto.get_attribute().iter().cloned());
                    convrelu_proto.set_attribute(RepeatedField::from(attributes));
                    convrelu_proto.set_name(format!(
                        "{}+{}",
                        conv.definition.get_name(),
                        relu.definition.get_name()
                    ));

                    log::debug!(
                        "can fuse chain of Conv/[Leaky]Relu to Conv[Leaky]Relu: {:?}: {:?} + {:?} = {}",
                        names,
                        conv.definition(),
                        relu.definition(),
                        convrelu_proto.get_name()
                    );

                    convrelu_def.proto = Cow::Owned(convrelu_proto);

                    let node = Arc::new(Node {
                        inputs: conv.inputs.clone(),
                        definition: NodeDefinition::Operator(Box::new(convrelu_def)),
                    });

                    // Replace the Conv and Relu nodes with the single fused node
                    chain.pop_front();
                    chain.pop_front();
                    chain.push_front(node);
                    Ok(true)
                } else {
                    unreachable!();
                }
            }
            _ => Ok(false),
        }
    }
}

// Names associated with the inputs of the Split, Resize, Reshape, Clip, Reduce* and Pad operators (in positional order - see ONNX spec)
static SPLIT_INPUT_NAMES: &[&str] = &["input", "split"];
static RESIZE_INPUT_NAMES: &[&str] = &["X", "roi", "scales", "sizes"];
static RESHAPE_INPUT_NAMES: &[&str] = &["data", "shape"];
static CLIP_INPUT_NAMES: &[&str] = &["input", "min", "max"];
static REDUCE_OPS_INPUT_NAMES: &[&str] = &["input", "axes"];
static PAD_INPUT_NAMES: &[&str] = &["data", "pads", "constant_value"];

/// Generate the output for a ConstantOfShape node
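/// If the node carries no `value` attribute, the output defaults to zeroes of type f32, per the ONNX specification.
/// For example, `constant_of_shape_output(&proto, 3)` for such a node yields `OutputTensor::F32(vec![0.0; 3])`.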
pub fn constant_of_shape_output(
    node: &NodeProto,
    element_count: usize,
) -> Result<OutputTensor, OptimizerError> {
    if let Ok(constant_value_tensor) = node.get_attribute_value::<TensorProto>("value", None) {
        match ScalarType::from_i32(constant_value_tensor.get_data_type()).map_err(|_| {
            OptimizerError::Unsupported(format!(
                "unsupported data type {}",
                constant_value_tensor.get_data_type()
            ))
        })? {
            ScalarType::F32 => {
                let fd = constant_value_tensor.get_float_data();
                if fd.is_empty() {
                    return Err(OptimizerError::InvalidNode(
                        "value tensor for ConstantOfShape is empty".to_string(),
                    ));
                }
                Ok(OutputTensor::F32(vec![fd[0]; element_count]))
            }
            ScalarType::I64 => {
                let fd = constant_value_tensor.get_int64_data();
                if fd.is_empty() {
                    return Err(OptimizerError::InvalidNode(
                        "value tensor for ConstantOfShape is empty".to_string(),
                    ));
                }
                Ok(OutputTensor::I64(vec![fd[0]; element_count]))
            }
            ScalarType::I32 => {
                let fd = constant_value_tensor.get_int32_data();
                if fd.is_empty() {
                    return Err(OptimizerError::InvalidNode(
                        "value tensor for ConstantOfShape is empty".to_string(),
                    ));
                }
                Ok(OutputTensor::I32(vec![fd[0]; element_count]))
            }
            ScalarType::U8 => {
                let fd = constant_value_tensor.get_raw_data();
                if fd.is_empty() {
                    return Err(OptimizerError::InvalidNode(
                        "value tensor for ConstantOfShape is empty".to_string(),
                    ));
                }
                Ok(OutputTensor::U8(vec![fd[0]; element_count]))
            }
        }
    } else {
        // The default value is a zero f32
        Ok(OutputTensor::F32(vec![0.0; element_count]))
    }
}

#[cfg(test)]
mod test {
    use std::sync::Arc;

    use crate::{
        ir::{self, Node, NodeDefinition},
        onnx::AttributeProto,
        utils::{attribute, graph, initializer, model, node, tensor},
    };

    use super::Optimizer;

    fn friendly_name(node: Arc<Node>) -> String {
        match node.definition() {
            NodeDefinition::Outputs { .. } => String::from("<outputs>"),
            NodeDefinition::Missing => String::from("<missing>"),
            NodeDefinition::Operator(op_def) => {
                format!("{}_{}", op_def.proto.get_op_type(), op_def.proto.get_name())
            }
            d => format!("{}", d.get_name()),
        }
    }

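    // Walks the graph from `node` towards its inputs, recording a (producer name, consumer name)
    // pair for every edge, so tests can assert on the resulting graph structure.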
    fn traverse(node: Arc<Node>, pairs: &mut Vec<(String, String)>) {
        let my_name = friendly_name(node.clone());
        for input in &node.inputs {
            let source_node_name = friendly_name(input.source_node.clone());
            pairs.push((source_node_name, my_name.to_string()))
        }

        for input in &node.inputs {
            traverse(input.source_node.clone(), pairs);
        }
    }

    // Test: X -> [Identity] A -> [Identity] -> Y => X -> Y
    #[test]
    pub fn test_optimize_identity_identity() {
        let _ = env_logger::builder().is_test(true).try_init();
        pollster::block_on(async {
            let m = model(graph(
                vec![tensor("X", &[1])],
                vec![tensor("Y", &[1])],
                vec![tensor("A", &[1])],
                vec![],
                vec![
                    node(vec!["X"], vec!["A"], "a", "Identity", vec![]),
                    node(vec!["A"], vec!["Y"], "b", "Identity", vec![]),
                ],
            ));

            let root = ir::Node::from_model(&m, None).unwrap();
            let mut opt = Optimizer::new(13);
            let new_root = opt.optimize(root).await.unwrap();
            let mut new_pairs = vec![];
            traverse(new_root, &mut new_pairs);
            assert_eq!(new_pairs, vec![("X".to_string(), "<outputs>".to_string())]);
        })
    }

    // Test: X -> [Neg] A -> [Neg] -> Y => X -> Y
    #[test]
    pub fn test_optimize_neg_neg() {
        let _ = env_logger::builder().is_test(true).try_init();
        pollster::block_on(async {
            let m = model(graph(
                vec![tensor("X", &[1])],
                vec![tensor("Y", &[1])],
                vec![tensor("A", &[1])],
                vec![],
                vec![
                    node(vec!["X"], vec!["A"], "a", "Neg", vec![]),
                    node(vec!["A"], vec!["Y"], "b", "Neg", vec![]),
                ],
            ));

            let root = ir::Node::from_model(&m, None).unwrap();
            let mut opt = Optimizer::new(13);
            let new_root = opt.optimize(root).await.unwrap();
            let mut new_pairs = vec![];
            traverse(new_root, &mut new_pairs);
            assert_eq!(new_pairs, vec![("X".to_string(), "<outputs>".to_string())]);
        });
    }

    // Test: X -> [Neg] A -> [Neg] B -> [Neg] -> Y => X -> [Neg] -> Y
    #[test]
    pub fn test_optimize_3neg() {
        pollster::block_on(async {
            let _ = env_logger::builder().is_test(true).try_init();

            let m = model(graph(
                vec![tensor("X", &[1])],
                vec![tensor("Y", &[1])],
                vec![tensor("A", &[1]), tensor("B", &[1])],
                vec![],
                vec![
                    node(vec!["X"], vec!["A"], "a", "Neg", vec![]),
                    node(vec!["A"], vec!["B"], "b", "Neg", vec![]),
                    node(vec!["B"], vec!["Y"], "c", "Neg", vec![]),
                ],
            ));

            let root = ir::Node::from_model(&m, None).unwrap();
            let mut opt = Optimizer::new(13);
            let new_root = opt.optimize(root).await.unwrap();
            let mut new_pairs = vec![];
            traverse(new_root, &mut new_pairs);
            assert_eq!(
                new_pairs,
                vec![
                    ("Neg_c".to_string(), "<outputs>".to_string()),
                    ("X".to_string(), "Neg_c".to_string())
                ]
            );
        });
    }

    // Test: X -> [Neg] A -> [Neg] B -> [Neg] C -> [Neg] -> Y => X -> Y
    #[test]
    pub fn test_optimize_4neg() {
        let _ = env_logger::builder().is_test(true).try_init();
        pollster::block_on(async {
            let m = model(graph(
                vec![tensor("X", &[1])],
                vec![tensor("Y", &[1])],
                vec![tensor("A", &[1]), tensor("B", &[1]), tensor("C", &[1])],
                vec![],
                vec![
                    node(vec!["X"], vec!["A"], "a", "Neg", vec![]),
                    node(vec!["A"], vec!["B"], "b", "Neg", vec![]),
                    node(vec!["B"], vec!["C"], "c", "Neg", vec![]),
                    node(vec!["C"], vec!["Y"], "d", "Neg", vec![]),
                ],
            ));

            let root = ir::Node::from_model(&m, None).unwrap();
            let mut opt = Optimizer::new(13);
            let new_root = opt.optimize(root).await.unwrap();
            let mut new_pairs = vec![];
            traverse(new_root, &mut new_pairs);
            assert_eq!(new_pairs, vec![("X".to_string(), "<outputs>".to_string())]);
        });
    }

    // Test: X -> [Neg] A -> [Neg] B -> [Neg] C -> [Neg] D -> [Neg] -> Y => X -> [Neg] -> Y
    #[test]
    pub fn test_optimize_5neg() {
        let _ = env_logger::builder().is_test(true).try_init();
        pollster::block_on(async {
            let m = model(graph(
                vec![tensor("X", &[1])],
                vec![tensor("Y", &[1])],
                vec![
                    tensor("A", &[1]),
                    tensor("B", &[1]),
                    tensor("C", &[1]),
                    tensor("D", &[1]),
                ],
                vec![],
                vec![
                    node(vec!["X"], vec!["A"], "a", "Neg", vec![]),
                    node(vec!["A"], vec!["B"], "b", "Neg", vec![]),
                    node(vec!["B"], vec!["C"], "c", "Neg", vec![]),
                    node(vec!["C"], vec!["D"], "d", "Neg", vec![]),
                    node(vec!["D"], vec!["Y"], "e", "Neg", vec![]),
                ],
            ));

            let root = ir::Node::from_model(&m, None).unwrap();
            let mut opt = Optimizer::new(13);
            let new_root = opt.optimize(root).await.unwrap();
            let mut new_pairs = vec![];
            traverse(new_root, &mut new_pairs);
            assert_eq!(
                new_pairs,
                vec![
                    ("Neg_e".to_string(), "<outputs>".to_string()),
                    ("X".to_string(), "Neg_e".to_string())
                ]
            );
        });
    }

    // Test: X -> [Neg] A -> [Neg] -> Y, with both A and Y as outputs => Y folds to X, while output A keeps its Neg
    #[test]
    pub fn test_optimize_neg_neg_branch() {
        let _ = env_logger::builder().is_test(true).try_init();
        pollster::block_on(async {
            let m = model(graph(
                vec![tensor("X", &[1])],
                vec![tensor("Y", &[1]), tensor("A", &[1])],
                vec![tensor("A", &[1])],
                vec![],
                vec![
                    node(vec!["X"], vec!["A"], "a", "Neg", vec![]),
                    node(vec!["A"], vec!["Y"], "b", "Neg", vec![]),
                ],
            ));

            let root = ir::Node::from_model(&m, None).unwrap();
            let mut opt = Optimizer::new(13);
            let new_root = opt.optimize(root).await.unwrap();
            let mut new_pairs = vec![];
            traverse(new_root, &mut new_pairs);
            assert_eq!(
                new_pairs,
                vec![
                    ("X".to_string(), "<outputs>".to_string()),
                    ("Neg_a".to_string(), "<outputs>".to_string()),
                    ("X".to_string(), "Neg_a".to_string())
                ]
            );
        });
    }

    // Test: X -> [Neg] A, A -> [Identity] -> Z, A -> [Identity] -> Y, with Y and Z as outputs
    // => both Identity nodes fold away and Neg_a feeds Y and Z directly
    #[test]
    pub fn test_optimize_identity_identity_two_outputs() {
        let _ = env_logger::builder().is_test(true).try_init();

        pollster::block_on(async {
            let m = model(graph(
                vec![tensor("X", &[1])],
                vec![tensor("Y", &[1]), tensor("Z", &[1])],
                vec![tensor("A", &[1])],
                vec![],
                vec![
                    node(vec!["X"], vec!["A"], "a", "Neg", vec![]),
                    node(vec!["A"], vec!["Z"], "b", "Identity", vec![]),
                    node(vec!["A"], vec!["Y"], "c", "Identity", vec![]),
                ],
            ));

            let root = ir::Node::from_model(&m, None).unwrap();
            let mut opt = Optimizer::new(13);
            let new_root = opt.optimize(root).await.unwrap();
            let mut new_pairs = vec![];
            traverse(new_root, &mut new_pairs);
            assert_eq!(
                new_pairs,
                vec![
                    ("Neg_a".to_string(), "<outputs>".to_string()),
                    ("Neg_a".to_string(), "<outputs>".to_string()),
                    ("X".to_string(), "Neg_a".to_string()),
                    ("X".to_string(), "Neg_a".to_string()),
                ]
            );
        });
    }

    // Test: A, B -> [Add] -> C where A, B are initializers
    #[test]
    pub fn test_constant_folding() {
        let _ = env_logger::builder().is_test(true).try_init();

        pollster::block_on(async {
            let m = model(graph(
                vec![],
                vec![tensor("C", &[1])],
                vec![],
                vec![
                    initializer("A", vec![21.0], vec![1]),
                    initializer("B", vec![7.0], vec![1]),
                ],
                vec![node(vec!["A", "B"], vec!["C"], "c", "Add", vec![])],
            ));

            let root = ir::Node::from_model(&m, None).unwrap();
            let mut opt = Optimizer::new(13);
            let new_root = opt.optimize(root).await.unwrap();
            let mut new_pairs = vec![];
            traverse(new_root, &mut new_pairs);
            assert_eq!(new_pairs, vec![("C".to_string(), "<outputs>".to_string())]);
        });
    }

    // Test: [Constant] -> Y => [initializer] -> Y
    #[test]
    pub fn test_constant_node_to_tensor() {
        let _ = env_logger::builder().is_test(true).try_init();

        pollster::block_on(async {
            let m = model(graph(
                vec![],
                vec![tensor("Y", &[1])],
                vec![],
                vec![],
                vec![node(
                    vec![],
                    vec!["Y"],
                    "y",
                    "Constant",
                    vec![attribute("value_float", 42.0)],
                )],
            ));

            let root = ir::Node::from_model(&m, None).unwrap();
            let mut opt = Optimizer::new(13);
            let new_root = opt.optimize(root).await.unwrap();
            let mut new_pairs = vec![];
            traverse(new_root.clone(), &mut new_pairs);
            assert_eq!(new_pairs, vec![("Y".to_string(), "<outputs>".to_string())]);

            let y_node = new_root.inputs[0].source_node.clone();
            assert!(matches!(y_node.definition(), NodeDefinition::Tensor(_)));
        });
    }

    // Test: Input X -> [Shape] -> Y => [initializer] -> Y with initializer containing the correct shape of input X
    #[test]
    pub fn test_shape_operator() {
        test_shape_operator_with(
            &[1, 2, 3],
            vec![attribute("start", -3), attribute("end", -2)],
            &[1],
        );
        test_shape_operator_with(&[1, 2, 3], vec![], &[1, 2, 3]);
        test_shape_operator_with(&[3, 4, 5], vec![attribute("start", 0)], &[3, 4, 5]);
        test_shape_operator_with(&[3, 4, 5], vec![attribute("start", 1)], &[4, 5]);
        test_shape_operator_with(&[3, 4, 5], vec![attribute("start", -1)], &[5]);
        test_shape_operator_with(&[3, 4, 5], vec![attribute("end", 10)], &[3, 4, 5]);
        test_shape_operator_with(&[3, 4, 5], vec![attribute("end", 1)], &[3]);
        test_shape_operator_with(
            &[3, 4, 5],
            vec![attribute("start", 10), attribute("end", 10)],
            &[],
        );
    }

    pub fn test_shape_operator_with(
        input_shape: &[i64],
        attrs: Vec<AttributeProto>,
        expected: &[i64],
    ) {
        let _ = env_logger::builder().is_test(true).try_init();

        pollster::block_on(async {
            let m = model(graph(
                vec![tensor("X", input_shape)],
                vec![tensor("Y", &[expected.len() as i64])],
                vec![],
                vec![],
                vec![node(vec!["X"], vec!["Y"], "y", "Shape", attrs)],
            ));

            let root = ir::Node::from_model(&m, None).unwrap();
            let mut opt = Optimizer::new(13);
            let new_root = opt.optimize(root).await.unwrap();
            let mut new_pairs = vec![];
            traverse(new_root.clone(), &mut new_pairs);
            assert_eq!(new_pairs, vec![("".to_string(), "<outputs>".to_string())]);

            let y_node = new_root.inputs[0].source_node.clone();
            let NodeDefinition::Tensor(t) = y_node.definition() else {
                panic!("should be folded to an initializer");
            };
            assert_eq!(t.get_int64_data(), expected);
        });
    }

    // Test: Input X -> [Size] -> Y => [initializer] -> Y with initializer containing the element count of input X
    #[test]
    pub fn test_size_operator() {
        test_size_operator_with(&[1, 2, 3], &[6]);
        test_size_operator_with(&[1], &[1]);
        test_size_operator_with(&[], &[1]);
    }

    pub fn test_size_operator_with(input_shape: &[i64], expected: &[i64]) {
        let _ = env_logger::builder().is_test(true).try_init();

        pollster::block_on(async {
            let m = model(graph(
                vec![tensor("X", input_shape)],
                vec![tensor("Y", &[expected.len() as i64])],
                vec![],
                vec![],
                vec![node(vec!["X"], vec!["Y"], "y", "Size", vec![])],
            ));

            let root = ir::Node::from_model(&m, None).unwrap();
            let mut opt = Optimizer::new(13);
            let new_root = opt.optimize(root).await.unwrap();
            let mut new_pairs = vec![];
            traverse(new_root.clone(), &mut new_pairs);
            assert_eq!(new_pairs, vec![("".to_string(), "<outputs>".to_string())]);

            let y_node = new_root.inputs[0].source_node.clone();
            let NodeDefinition::Tensor(t) = y_node.definition() else {
                panic!("should be folded to an initializer");
            };
            assert_eq!(t.get_int64_data(), expected);
        });
    }
1271}