use crate::{
    gpu::GpuModel,
    ir::{Input, Node, NodeDefinition, NodeIdentifier, OperatorDefinition},
    onnx::{NodeProto, TensorProto},
    resource::{padding, request_device_queue},
    utils::{
        attribute, AttributeNotFoundError, DataTypeError, NodeAttributes, OutputTensor, ScalarType,
        Shape,
    },
    GpuError,
};
use async_recursion::async_recursion;
use bytemuck::pod_collect_to_vec;
use protobuf::RepeatedField;
use std::{
    borrow::Cow,
    collections::{HashMap, VecDeque},
    sync::Arc,
};
use thiserror::Error;

#[derive(Debug, Error)]
pub enum OptimizerError {
    #[error("node has no inputs")]
    NoInputs,

    #[error("unsupported: {0}")]
    Unsupported(String),

    #[error("invalid data type {data_type:?} for input {input} of op {op}")]
    InvalidInputDataType {
        data_type: ScalarType,
        input: String,
        op: String,
    },

    #[error("error with data type: {0}")]
    InvalidDataType(#[from] DataTypeError),

    #[error("node is invalid: {0}")]
    InvalidNode(String),

    #[error("required attribute not found: {0}")]
    AttributeNotFound(#[from] AttributeNotFoundError),

    #[error("error during constant folding: {0}")]
    ConstantFoldingError(#[from] GpuError),
}

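/// Optimizes an IR graph: folds constant subgraphs to tensors, statically
/// evaluates Shape/Size/Constant nodes, fuses operator chains (e.g. Conv
/// followed by Relu), removes no-ops such as Identity and double Neg, and
/// moves static tensor inputs of certain operators into attributes.
///
/// A minimal usage sketch (mirroring the tests below; the opset version is
/// that of the source ONNX model):
///
/// ```ignore
/// let root = ir::Node::from_model(&model, None)?;
/// let mut opt = Optimizer::new(13);
/// let optimized_root = opt.optimize(root).await?;
/// ```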
pub struct Optimizer<'model> {
    padded_tensors: HashMap<String, Arc<Node<'model>>>,
    optimized: HashMap<NodeIdentifier<'model>, Arc<Node<'model>>>,
    onnx_opset_version: i64,
}

impl<'model> Optimizer<'model> {
    pub fn new(onnx_opset_version: i64) -> Self {
        Self {
            padded_tensors: HashMap::new(),
            optimized: HashMap::new(),
            onnx_opset_version,
        }
    }

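    /// Attempts to replace a constant operator node with a plain tensor node.
    /// `Constant` nodes are converted directly; any other constant operator is
    /// evaluated on the GPU (see `infer_constant_node_to_tensor`). Returns
    /// `Ok(None)` when the node cannot be folded (yet).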
    async fn fold_constant_node(
        &self,
        node: Arc<Node<'model>>,
    ) -> Result<Option<Arc<Node<'model>>>, OptimizerError> {
        assert!(node.is_constant());

        match node.definition() {
            NodeDefinition::Operator(op_def) => {
                if op_def.proto.output.len() != 1 {
                    log::warn!(
                        "node {:?} is constant, but does not have exactly one output, which we can't fold yet",
                        node.definition()
                    );
                    return Ok(None);
                }

                match op_def.proto.get_op_type() {
                    "Constant" => Ok(Some(Arc::new(Node {
                        definition: NodeDefinition::Tensor(Box::new(Cow::Owned(
                            Self::constant_node_to_tensor(node)?,
                        ))),
                        inputs: vec![],
                    }))),
                    _ => self.infer_constant_node_to_tensor(node.clone()).await,
                }
            }
            // Tensor nodes are already constant and need no folding
            NodeDefinition::Tensor(_) => Ok(None),
            NodeDefinition::Input(_) | NodeDefinition::Missing => unreachable!(),
            NodeDefinition::Outputs { .. } => Ok(None),
        }
    }

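    /// Statically evaluates a `Shape` node: determines the shape of its single
    /// input and returns it as an int64 tensor, honoring the optional `start`
    /// and `end` attributes (which may be negative, counting from the back).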
    fn shape_node_to_tensor(node: Arc<Node<'model>>) -> Result<TensorProto, OptimizerError> {
        let NodeDefinition::Operator(op_def) = node.definition() else {
            panic!("node must be a Shape node");
        };
        assert_eq!(op_def.proto.get_op_type(), "Shape");

        if node.inputs.len() != 1 {
            return Err(OptimizerError::InvalidNode(format!(
                "Shape node should only have one input, has {}",
                node.inputs.len()
            )));
        }

        let input = &node.inputs[0];
        let in_node = &input.source_node.definition;
        let in_shape = match in_node {
            NodeDefinition::Input(input) => input.get_shape()?,
            NodeDefinition::Operator(input_op_def) => {
                input_op_def.output_shapes[input.output_index].clone()
            }
            NodeDefinition::Tensor(input_tensor) => Shape::from(
                ScalarType::from_i32(input_tensor.get_data_type())
                    .map_err(OptimizerError::InvalidDataType)?,
                input_tensor.get_dims(),
            ),
            NodeDefinition::Outputs { .. } => {
                return Err(OptimizerError::Unsupported(
                    "output node cannot be used as an input to Shape node".to_string(),
                ))
            }
            NodeDefinition::Missing => {
                return Err(OptimizerError::InvalidNode(
                    "Shape node has missing input".to_string(),
                ))
            }
        };
        let rank = in_shape.rank() as i64;
        let mut start: i64 = op_def.proto.get_attribute_value("start", Some(0)).unwrap();
        let mut end: i64 = op_def.proto.get_attribute_value("end", Some(rank)).unwrap();
        // Negative start/end indices count from the back; both are then
        // clamped to [0, rank]
        if start < 0 {
            start += rank;
        }
        if end < 0 {
            end += rank;
        }
        start = start.clamp(0, rank);
        end = end.clamp(0, rank);

        // After clamping, only end < start remains to be checked; an empty
        // range (end == start) is valid and yields an empty shape
        if end < start {
            return Err(OptimizerError::InvalidNode(format!(
                "end index of Shape node cannot be below its start index (start={start}, end={end})"
            )));
        }

        let values: Vec<i64> = in_shape.dims[(start as usize)..(end as usize)]
            .iter()
            .map(|x| *x as i64)
            .collect();
        let dims = vec![values.len() as i64];
        Ok(TensorProto::from(OutputTensor::I64(values), dims))
    }

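    /// Converts a `Constant` operator node to the equivalent `TensorProto`,
    /// reading the value from whichever of the `value_floats`, `value_ints`,
    /// `value_float`, `value_int` or `value` attributes is set.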
    fn constant_node_to_tensor(node: Arc<Node<'model>>) -> Result<TensorProto, OptimizerError> {
        let NodeDefinition::Operator(op_def) = node.definition() else {
            panic!("node must be a Constant node");
        };
        assert_eq!(op_def.proto.get_op_type(), "Constant");
        let proto = &op_def.proto;
        let output_name = proto.output.get(0).unwrap().to_owned();

        let mut tp: TensorProto =
            if let Ok(values) = proto.get_attribute_value::<Vec<f32>>("value_floats", None) {
                let dims = vec![values.len() as i64];
                TensorProto::from(OutputTensor::F32(values), dims)
            } else if let Ok(values) = proto.get_attribute_value::<Vec<i64>>("value_ints", None) {
                let dims = vec![values.len() as i64];
                TensorProto::from(OutputTensor::I64(values), dims)
            } else if let Ok(value) = proto.get_attribute_value::<f32>("value_float", None) {
                TensorProto::from(OutputTensor::F32(vec![value]), vec![1])
            } else if let Ok(value) = proto.get_attribute_value::<i64>("value_int", None) {
                TensorProto::from(OutputTensor::I64(vec![value]), vec![1])
            } else if let Ok(tp) = proto.get_attribute_value::<TensorProto>("value", None) {
                tp
            } else {
                return Err(OptimizerError::Unsupported(
                    "Constant node with unknown value type".to_string(),
                ));
            };

        tp.set_name(output_name);
        Ok(tp)
    }

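    /// Statically evaluates a `Size` node: returns the total number of
    /// elements of its single input as a scalar int64 tensor.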
    fn size_node_to_tensor(node: Arc<Node<'model>>) -> Result<TensorProto, OptimizerError> {
        let NodeDefinition::Operator(op_def) = node.definition() else {
            panic!("node must be a Size node");
        };
        assert_eq!(op_def.proto.get_op_type(), "Size");

        if node.inputs.len() != 1 {
            return Err(OptimizerError::InvalidNode(format!(
                "Size node should only have one input, has {}",
                node.inputs.len()
            )));
        }

        let input = &node.inputs[0];
        let in_node = &input.source_node.definition;
        let in_element_count: i64 = match in_node {
            NodeDefinition::Input(input) => input.get_shape()?.element_count() as i64,
            NodeDefinition::Operator(input_op_def) => {
                input_op_def.output_shapes[input.output_index].element_count() as i64
            }
            NodeDefinition::Tensor(input_tensor) => input_tensor.get_dims().iter().product(),
            NodeDefinition::Outputs { .. } => {
                return Err(OptimizerError::Unsupported(
                    "output node cannot be used as an input to Size node".to_string(),
                ))
            }
            NodeDefinition::Missing => {
                return Err(OptimizerError::InvalidNode(
                    "Size node has missing input".to_string(),
                ))
            }
        };

        Ok(TensorProto::from(
            OutputTensor::I64(vec![in_element_count]),
            vec![1],
        ))
    }

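    /// Evaluates a constant operator node by wrapping it in a minimal
    /// single-output model, running that model on the GPU, and converting the
    /// inferred output into a tensor node that replaces the original subgraph.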
    async fn infer_constant_node_to_tensor(
        &self,
        node: Arc<Node<'model>>,
    ) -> Result<Option<Arc<Node<'model>>>, OptimizerError> {
        assert!(node.is_constant());

        if let NodeDefinition::Operator(op_def) = node.definition() {
            let output_name = op_def.proto.output.get(0).unwrap().to_owned();

            // Wrap the constant node in a trivial model with a single output
            let out_node = Arc::new(Node {
                definition: NodeDefinition::Outputs {
                    names: vec!["output".to_string()],
                },
                inputs: vec![Input {
                    source_node: node.clone(),
                    output_index: 0,
                }],
            });

            // Run the model on the GPU and take its single output
            let (device, queue) = request_device_queue().await;
            let gm = GpuModel::from(out_node, device, queue, self.onnx_opset_version)
                .map_err(OptimizerError::ConstantFoldingError)?;
            let mut outputs = gm.infer(&HashMap::new()).await?;

            let (_, output_tensor) = outputs.drain().next().unwrap();
            log::info!("folded {output_name} to {output_tensor:?}");
            let mut output_tensor_proto = TensorProto::from(
                output_tensor,
                op_def.output_shapes[0]
                    .dims
                    .iter()
                    .map(|x| *x as i64)
                    .collect(),
            );
            output_tensor_proto.set_name(output_name);

            let tensor_node = Node {
                definition: NodeDefinition::Tensor(Box::new(Cow::Owned(output_tensor_proto))),
                inputs: vec![],
            };

            Ok(Some(Arc::new(tensor_node)))
        } else {
            panic!("node to fold must be operator")
        }
    }

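    /// Optimizes a node (and, recursively, all of its inputs), memoizing the
    /// result so that shared subgraphs are optimized only once and keep
    /// pointing to the same optimized node.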
    #[async_recursion]
    pub async fn optimize(
        &mut self,
        node: Arc<Node<'model>>,
    ) -> Result<Arc<Node<'model>>, OptimizerError> {
        let identifier = node.identifier();
        match self.optimized.get(&identifier) {
            Some(opt_node) => Ok(opt_node.clone()),
            None => {
                let opt_node = self.optimize_actual(node).await?;
                self.optimized.insert(identifier, opt_node.clone());
                Ok(opt_node)
            }
        }
    }

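    /// Performs the actual optimization of a node: gathers the chain of nodes
    /// that each have exactly one dynamic input and that ends at this node,
    /// repeatedly applies the rewrites in `optimize_chain` to it, and finally
    /// rewires and recursively optimizes the remaining nodes and their inputs.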
    #[async_recursion]
    async fn optimize_actual(
        &mut self,
        node: Arc<Node<'model>>,
    ) -> Result<Arc<Node<'model>>, OptimizerError> {
        // Walk upwards from this node, collecting the chain of nodes that each
        // have exactly one dynamic input; `prior` ends up being the first node
        // above the chain, which is optimized separately and feeds the chain.
        let prior;
        let mut chain = VecDeque::new();
        chain.push_back(node.clone());

        loop {
            let head = chain.front().unwrap();
            let dynamic_inputs = head
                .inputs
                .iter()
                .filter(|input| input.source_node.is_dynamic() && input.output_index == 0)
                .collect::<Vec<&Input>>();

            if dynamic_inputs.len() != 1 {
                prior = chain.pop_front().unwrap();
                break;
            }
            chain.push_front(dynamic_inputs[0].source_node.clone());
        }

        log::debug!(
            "optimize: node={:?} def={:?} chain={}, next={:?}",
            node.identifier(),
            node.definition,
            chain
                .iter()
                .map(|x| format!("[{:?}]", x.definition))
                .collect::<Vec<String>>()
                .join(" -> "),
            prior.identifier()
        );

        if chain.len() > 1 {
            // Repeatedly apply the chain rewrites (e.g. fusing Conv+Relu) until
            // none applies anymore, then peel off the head of the chain and
            // continue with the rest.
            let mut final_chain: Vec<Arc<Node>> = vec![];
            while !chain.is_empty() {
                log::debug!("optimize chain {}", chain.len());
                while self.optimize_chain(&mut chain)? {
                    log::debug!("optimize chain succeeded {}", chain.len());
                }

                if !chain.is_empty() {
                    let first = chain.pop_front().unwrap();
                    final_chain.push(first);
                }

                log::debug!(
                    "optimized chain: {}",
                    final_chain
                        .iter()
                        .map(|x| format!("[{:?}]", x.definition))
                        .collect::<Vec<String>>()
                        .join(" -> ")
                );
            }
            drop(chain);

            let optimized_next = self.optimize(prior).await?;

            if final_chain.is_empty() {
                return Ok(optimized_next);
            }

            // Rewire each node in the optimized chain so that its dynamic
            // input points to the (optimized) node preceding it in the chain.
            for node_index in 0..final_chain.len() {
                let consumer = final_chain[node_index].clone();
                let producer = if node_index == 0 {
                    optimized_next.clone()
                } else {
                    final_chain[node_index - 1].clone()
                };
                final_chain[node_index] = self
                    .locally_optimized_node_with(
                        consumer.clone(),
                        consumer
                            .inputs
                            .iter()
                            .map(|old_input| {
                                let is_dynamic_source = old_input.source_node.is_dynamic()
                                    && old_input.output_index == 0;
                                if is_dynamic_source {
                                    Input {
                                        source_node: producer.clone(),
                                        output_index: 0,
                                    }
                                } else {
                                    old_input.clone()
                                }
                            })
                            .collect(),
                    )
                    .await?;
            }

            Ok(final_chain.last().unwrap().clone())
        } else {
            // No chain to rewrite; just optimize all inputs recursively
            let mut new_inputs = Vec::with_capacity(node.inputs.len());
            for input in node.inputs.iter() {
                new_inputs.push(Input {
                    source_node: self.optimize(input.source_node.clone()).await?,
                    output_index: input.output_index,
                });
            }
            self.locally_optimized_node_with(node.clone(), new_inputs)
                .await
        }
    }

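    /// Rewrites a single node given its already-optimized inputs: statically
    /// evaluates Shape and Size nodes, folds constant subgraphs, pads Conv
    /// weights where a specialized kernel can be used, and moves static tensor
    /// inputs of certain operators into attributes.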
    async fn locally_optimized_node_with(
        &mut self,
        node: Arc<Node<'model>>,
        mut new_inputs: Vec<Input<'model>>,
    ) -> Result<Arc<Node<'model>>, OptimizerError> {
        log::debug!(
            "locally_optimized_node_with {:?} {:?}",
            node.identifier(),
            node.definition()
        );

        // Shape and Size nodes can be evaluated statically
        if let NodeDefinition::Operator(op_def) = &node.definition {
            match op_def.proto.get_op_type() {
                "Shape" => {
                    return Ok(Arc::new(Node {
                        definition: NodeDefinition::Tensor(Box::new(Cow::Owned(
                            Self::shape_node_to_tensor(node)?,
                        ))),
                        inputs: vec![],
                    }))
                }
                "Size" => {
                    return Ok(Arc::new(Node {
                        definition: NodeDefinition::Tensor(Box::new(Cow::Owned(
                            Self::size_node_to_tensor(node)?,
                        ))),
                        inputs: vec![],
                    }))
                }
                _ => {}
            }
        }

        // Fold constant subgraphs to a single tensor node where possible
        if node.is_constant() && !matches!(node.definition, NodeDefinition::Missing) {
            log::debug!(
                "node is constant: {:?} {:?}",
                node.identifier(),
                node.definition()
            );
            if let Some(const_node) = self.fold_constant_node(node.clone()).await? {
                return Ok(const_node);
            }
        }

        match &node.definition {
            NodeDefinition::Operator(op_def) => {
                match op_def.proto.get_op_type() {
                    "Conv" | "ConvRelu" | "ConvLeakyRelu" => {
                        // A specialized kernel can be used for 3x3 convolutions
                        // with strides of one, no grouping and an output channel
                        // count divisible by four; it expects the weight tensor
                        // data to be padded.
                        if new_inputs.len() > 2
                            && op_def
                                .proto
                                .get_attribute_value::<Vec<i64>>("kernel_shape", None)?
                                == [3, 3]
                            && (op_def
                                .proto
                                .get_attribute_value("pads", Some(vec![0, 0, 0, 0]))?
                                == [1, 1, 1, 1]
                                || op_def.proto.get_attribute_value(
                                    "auto_pad",
                                    Some("SAME_UPPER".to_string()),
                                )? == "SAME_UPPER")
                            && op_def
                                .proto
                                .get_attribute_value("strides", Some(vec![1, 1]))?
                                == [1, 1]
                            && op_def.proto.get_attribute_value("group", Some(1))? == 1
                            && op_def.output_shapes[0].dim(1) % 4 == 0
                        {
                            if let NodeDefinition::Tensor(tensor) =
                                &new_inputs[1].source_node.definition
                            {
                                new_inputs[1] = Input {
                                    output_index: 0,
                                    // Padded tensors are cached so they can be
                                    // shared between Conv nodes using the same
                                    // weights
                                    source_node: match self.padded_tensors.get(tensor.get_name()) {
                                        Some(padded_tensor_node) => padded_tensor_node.clone(),
                                        None => {
                                            let data = tensor.get_float_data();
                                            let raw_data = if !data.is_empty() {
                                                bytemuck::cast_slice(data)
                                            } else {
                                                tensor.get_raw_data()
                                            };

                                            // Pad the raw weight data (chunks of
                                            // 12 bytes, each padded by 4 bytes)
                                            let padded_raw_data = padding(raw_data, 12, 4);

                                            log::info!(
                                                "applying padding optimization to tensor {}: strides data is {} bytes before, {} bytes after",
                                                tensor.get_name(),
                                                raw_data.len(),
                                                padded_raw_data.len()
                                            );

                                            let mut new_tensor = tensor.clone().into_owned();
                                            new_tensor.set_float_data(vec![]);
                                            new_tensor.set_raw_data(padded_raw_data);
                                            let new_node = Arc::new(Node {
                                                definition: NodeDefinition::Tensor(Box::new(
                                                    Cow::Owned(new_tensor),
                                                )),
                                                inputs: vec![],
                                            });
                                            self.padded_tensors.insert(
                                                tensor.get_name().to_string(),
                                                new_node.clone(),
                                            );
                                            new_node
                                        }
                                    },
                                }
                            }
                        }

                        let new_node = Node {
                            inputs: new_inputs,
                            definition: NodeDefinition::Operator(op_def.clone()),
                        };

                        Ok(Arc::new(new_node))
                    }

556 op @ ("Clip" | "Pad" | "Split" | "Resize" | "Reshape" | "ReduceMean"
560 | "ReduceSum" | "ReduceMin" | "ReduceMax" | "ReduceSumSquare"
561 | "ReduceLogSumExp" | "ReduceLogSum" | "ReduceL2" | "ReduceL1"
562 | "ReduceProd") => {
563 if new_inputs.is_empty() {
564 return Err(OptimizerError::NoInputs);
565 }
566
567 let attr_names = match op {
569 "Split" => SPLIT_INPUT_NAMES,
570 "Resize" => RESIZE_INPUT_NAMES,
571 "Reshape" => RESHAPE_INPUT_NAMES,
572 "Clip" => CLIP_INPUT_NAMES,
573 "Pad" => PAD_INPUT_NAMES,
574 "ReduceSum" => REDUCE_OPS_INPUT_NAMES,
575 "ReduceL1" => REDUCE_OPS_INPUT_NAMES,
576 "ReduceL2" => REDUCE_OPS_INPUT_NAMES,
577 "ReduceLogSum" => REDUCE_OPS_INPUT_NAMES,
578 "ReduceLogSumExp" => REDUCE_OPS_INPUT_NAMES,
579 "ReduceMax" => REDUCE_OPS_INPUT_NAMES,
580 "ReduceMean" => REDUCE_OPS_INPUT_NAMES,
581 "ReduceMin" => REDUCE_OPS_INPUT_NAMES,
582 "ReduceProd" => REDUCE_OPS_INPUT_NAMES,
583 "ReduceSumSquare" => REDUCE_OPS_INPUT_NAMES,
584 _ => unreachable!(),
585 };
586
587 let mut new_proto = op_def.proto.clone().into_owned();
589 let mut attributes = op_def.proto.get_attribute().to_vec();
590
591 for input_index in 1..(new_inputs.len().min(attr_names.len())) {
593 let source_node = &new_inputs[input_index].source_node;
594 match &source_node.definition {
595 NodeDefinition::Tensor(tensor_proto) => {
597 let attr_name = attr_names[input_index];
598 let data_type =
599 ScalarType::from_i32(tensor_proto.get_data_type())?;
600
601 match (op, attr_name) {
602 ("Split", "split")
603 | ("Resize", "roi")
604 | ("Resize", "sizes")
605 | ("Reshape", "shape")
606 | (
607 "ReduceMean" | "ReduceSum" | "ReduceMin" | "ReduceMax"
608 | "ReduceSumSquare" | "ReduceLogSumExp"
609 | "ReduceLogSum" | "ReduceL2" | "ReduceL1"
610 | "ReduceProd",
611 "axes",
612 )
613 | ("Pad", "pads")
614 | ("Resize", "scales")
615 | ("Clip", "min" | "max") => match data_type {
616 ScalarType::F32 => {
617 let value: Vec<f32> = if tensor_proto
618 .get_float_data()
619 .is_empty()
620 {
621 pod_collect_to_vec(tensor_proto.get_raw_data())
622 } else {
623 tensor_proto.get_float_data().to_vec()
624 };
625 log::info!(
626 "transferring input {} for op {} to f32 attribute (initializer data type: {:?}): {:?}",
627 attr_name,
628 op,
629 data_type,
630 value,
631 );
632 attributes.push(attribute(
633 attr_names[input_index],
634 value,
635 ));
636 }
637 ScalarType::I64 => {
638 let value = if tensor_proto
639 .get_int64_data()
640 .is_empty()
641 {
642 pod_collect_to_vec(tensor_proto.get_raw_data())
643 } else {
644 tensor_proto.get_int64_data().to_vec()
645 };
646 log::info!(
647 "transferring input {} for op {} to i64 attribute (initializer data type: {:?}): {:?}",
648 attr_name,
649 op,
650 data_type,
651 value,
652 );
653 attributes.push(attribute(
654 attr_names[input_index],
655 value,
656 ));
657 }
658 _ => {
659 return Err(OptimizerError::InvalidInputDataType {
660 data_type,
661 input: attr_name.to_string(),
662 op: op.to_string(),
663 })
664 }
665 },
666 _ => {
667 return Err(OptimizerError::Unsupported(format!(
669 "data_type {} for input {} to op {}",
670 tensor_proto.get_data_type(),
671 attr_name,
672 op
673 )));
674 }
675 }
676 }
677 NodeDefinition::Missing => {
678 }
680 _ => {
681 return Err(OptimizerError::Unsupported(format!(
683 "{} operation with dynamic input for {}",
684 op, attr_names[input_index]
685 )));
686 }
687 }
688 }
689
690 new_proto.set_attribute(RepeatedField::from(attributes));
692
693 let new_node = Node {
694 inputs: vec![new_inputs[0].clone()],
695 definition: NodeDefinition::Operator(Box::new(OperatorDefinition {
696 proto: Cow::Owned(new_proto),
697 output_shapes: op_def.output_shapes.clone(),
698 })),
699 };
700
701 Ok(Arc::new(new_node))
702 }
703
704 _ => Ok(Arc::new(Node {
705 inputs: new_inputs,
706 definition: NodeDefinition::Operator(op_def.clone()),
707 })),
708 }
709 }
            NodeDefinition::Tensor(..) | NodeDefinition::Input(..) => {
                assert!(
                    new_inputs.is_empty(),
                    "non-operator node cannot have inputs"
                );
                Ok(node.clone())
            }
            NodeDefinition::Outputs { .. } => Ok(Arc::new(Node {
                inputs: new_inputs,
                definition: node.definition().clone(),
            })),
            NodeDefinition::Missing => Ok(node.clone()),
        }
    }

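    /// Attempts to rewrite the front of a chain of nodes: removes Identity
    /// nodes and pairs of consecutive Neg nodes, and fuses Conv followed by
    /// Relu/LeakyRelu into a single ConvRelu/ConvLeakyRelu node. Returns
    /// `Ok(true)` when a rewrite was applied (and may be applicable again).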
    fn optimize_chain(
        &mut self,
        chain: &mut VecDeque<Arc<Node<'model>>>,
    ) -> Result<bool, OptimizerError> {
        // Identity nodes are no-ops and can always be removed
        chain.retain(|n| match &n.definition {
            NodeDefinition::Operator(op_def) => op_def.proto.get_op_type() != "Identity",
            _ => true,
        });

        let names: Vec<&str> = chain
            .iter()
            .map(|x| match &x.definition {
                NodeDefinition::Operator(op_def) => op_def.proto.get_op_type(),
                _ => "",
            })
            .collect();

        log::debug!("optimize_chain {:?}", names);

        match &names[..] {
            // Two consecutive negations cancel each other out
            ["Neg", "Neg", ..] => {
                chain.pop_front();
                chain.pop_front();
                Ok(true)
            }

            // Conv followed by an activation can be fused into a single node
            ["Conv", "Relu", ..] | ["Conv", "LeakyRelu", ..] => {
                let conv = chain[0].clone();
                let relu = chain[1].clone();

                if let (NodeDefinition::Operator(conv_def), NodeDefinition::Operator(relu_def)) =
                    (&conv.definition, &relu.definition)
                {
                    let mut convrelu_def = *conv_def.clone();
                    let mut convrelu_proto = conv_def.proto.clone().into_owned();
                    let new_op_type = match relu_def.proto.get_op_type() {
                        "LeakyRelu" => "ConvLeakyRelu",
                        "Relu" => "ConvRelu",
                        _ => unreachable!(),
                    };
                    convrelu_proto.set_op_type(new_op_type.to_string());

                    // The fused node carries the attributes of both nodes
                    let mut attributes = conv_def.proto.get_attribute().to_vec();
                    attributes.extend(relu_def.proto.get_attribute().iter().cloned());
                    convrelu_proto.set_attribute(RepeatedField::from(attributes));
                    convrelu_proto.set_name(format!(
                        "{}+{}",
                        conv.definition.get_name(),
                        relu.definition.get_name()
                    ));

                    log::debug!(
                        "can fuse chain of Conv/[Leaky]Relu to Conv[Leaky]Relu: {:?}: {:?} + {:?} = {}",
                        names,
                        conv.definition(),
                        relu.definition(),
                        convrelu_proto.get_name()
                    );

                    convrelu_def.proto = Cow::Owned(convrelu_proto);

                    let node = Arc::new(Node {
                        inputs: conv.inputs.clone(),
                        definition: NodeDefinition::Operator(Box::new(convrelu_def)),
                    });

                    chain.remove(0);
                    chain.remove(0);
                    chain.insert(0, node);
                    Ok(true)
                } else {
                    unreachable!();
                }
            }
            _ => Ok(false),
        }
    }
}

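// Names of the inputs of these operators in positional order (see the ONNX
// operator specifications); used above when transferring static inputs to
// attributes.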
static SPLIT_INPUT_NAMES: &[&str] = &["input", "split"];
static RESIZE_INPUT_NAMES: &[&str] = &["X", "roi", "scales", "sizes"];
static RESHAPE_INPUT_NAMES: &[&str] = &["data", "shape"];
static CLIP_INPUT_NAMES: &[&str] = &["input", "min", "max"];
static REDUCE_OPS_INPUT_NAMES: &[&str] = &["input", "axes"];
static PAD_INPUT_NAMES: &[&str] = &["data", "pads", "constant_value"];

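/// Computes the output of a ConstantOfShape node: a tensor of `element_count`
/// elements, each a copy of the scalar held in the node's `value` attribute
/// (per the ONNX spec, a float32 zero when the attribute is absent).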
pub fn constant_of_shape_output(
    node: &NodeProto,
    element_count: usize,
) -> Result<OutputTensor, OptimizerError> {
    if let Ok(constant_value_tensor) = node.get_attribute_value::<TensorProto>("value", None) {
        match ScalarType::from_i32(constant_value_tensor.get_data_type()).map_err(|_| {
            OptimizerError::Unsupported(format!(
                "unsupported data type {}",
                constant_value_tensor.get_data_type()
            ))
        })? {
            ScalarType::F32 => {
                let fd = constant_value_tensor.get_float_data();
                if fd.is_empty() {
                    return Err(OptimizerError::InvalidNode(
                        "value tensor for ConstantOfShape is empty".to_string(),
                    ));
                }
                Ok(OutputTensor::F32(vec![fd[0]; element_count]))
            }
            ScalarType::I64 => {
                let fd = constant_value_tensor.get_int64_data();
                if fd.is_empty() {
                    return Err(OptimizerError::InvalidNode(
                        "value tensor for ConstantOfShape is empty".to_string(),
                    ));
                }
                Ok(OutputTensor::I64(vec![fd[0]; element_count]))
            }
            ScalarType::I32 => {
                let fd = constant_value_tensor.get_int32_data();
                if fd.is_empty() {
                    return Err(OptimizerError::InvalidNode(
                        "value tensor for ConstantOfShape is empty".to_string(),
                    ));
                }
                Ok(OutputTensor::I32(vec![fd[0]; element_count]))
            }
            ScalarType::U8 => {
                let fd = constant_value_tensor.get_raw_data();
                if fd.is_empty() {
                    return Err(OptimizerError::InvalidNode(
                        "value tensor for ConstantOfShape is empty".to_string(),
                    ));
                }
                Ok(OutputTensor::U8(vec![fd[0]; element_count]))
            }
        }
    } else {
        // No value attribute given: default to float32 zeros per the ONNX spec
        Ok(OutputTensor::F32(vec![0.0; element_count]))
    }
}

#[cfg(test)]
mod test {
    use std::sync::Arc;

    use crate::{
        ir::{self, Node, NodeDefinition},
        onnx::AttributeProto,
        utils::{attribute, graph, initializer, model, node, tensor},
    };

    use super::Optimizer;

    /// Returns a short, human-readable name for a node, used to describe graph
    /// edges in the assertions below.
    fn friendly_name(node: Arc<Node>) -> String {
        match node.definition() {
            NodeDefinition::Outputs { .. } => String::from("<outputs>"),
            NodeDefinition::Missing => String::from("<missing>"),
            NodeDefinition::Operator(op_def) => {
                format!("{}_{}", op_def.proto.get_op_type(), op_def.proto.get_name())
            }
            d => format!("{}", d.get_name()),
        }
    }

    /// Walks the graph upwards from `node`, collecting a (source, target) name
    /// pair for every edge.
    fn traverse(node: Arc<Node>, pairs: &mut Vec<(String, String)>) {
        let my_name = friendly_name(node.clone());
        for input in &node.inputs {
            let source_node_name = friendly_name(input.source_node.clone());
            pairs.push((source_node_name, my_name.to_string()))
        }

        for input in &node.inputs {
            traverse(input.source_node.clone(), pairs);
        }
    }

    // Tests that X -> [Identity] -> A -> [Identity] -> Y is optimized to X -> Y
    #[test]
    pub fn test_optimize_identity_identity() {
        let _ = env_logger::builder().is_test(true).try_init();
        pollster::block_on(async {
            let m = model(graph(
                vec![tensor("X", &[1])],
                vec![tensor("Y", &[1])],
                vec![tensor("A", &[1])],
                vec![],
                vec![
                    node(vec!["X"], vec!["A"], "a", "Identity", vec![]),
                    node(vec!["A"], vec!["Y"], "b", "Identity", vec![]),
                ],
            ));

            let root = ir::Node::from_model(&m, None).unwrap();
            let mut opt = Optimizer::new(13);
            let new_root = opt.optimize(root).await.unwrap();
            let mut new_pairs = vec![];
            traverse(new_root, &mut new_pairs);
            assert_eq!(new_pairs, vec![("X".to_string(), "<outputs>".to_string())]);
        })
    }

    // Tests that a double negation (X -> [Neg] -> A -> [Neg] -> Y) is optimized away
    #[test]
    pub fn test_optimize_neg_neg() {
        let _ = env_logger::builder().is_test(true).try_init();
        pollster::block_on(async {
            let m = model(graph(
                vec![tensor("X", &[1])],
                vec![tensor("Y", &[1])],
                vec![tensor("A", &[1])],
                vec![],
                vec![
                    node(vec!["X"], vec!["A"], "a", "Neg", vec![]),
                    node(vec!["A"], vec!["Y"], "b", "Neg", vec![]),
                ],
            ));

            let root = ir::Node::from_model(&m, None).unwrap();
            let mut opt = Optimizer::new(13);
            let new_root = opt.optimize(root).await.unwrap();
            let mut new_pairs = vec![];
            traverse(new_root, &mut new_pairs);
            assert_eq!(new_pairs, vec![("X".to_string(), "<outputs>".to_string())]);
        });
    }

    // Tests that a chain of three Negs is reduced to a single Neg
    #[test]
    pub fn test_optimize_3neg() {
        pollster::block_on(async {
            let _ = env_logger::builder().is_test(true).try_init();

            let m = model(graph(
                vec![tensor("X", &[1])],
                vec![tensor("Y", &[1])],
                vec![tensor("A", &[1]), tensor("B", &[1])],
                vec![],
                vec![
                    node(vec!["X"], vec!["A"], "a", "Neg", vec![]),
                    node(vec!["A"], vec!["B"], "b", "Neg", vec![]),
                    node(vec!["B"], vec!["Y"], "c", "Neg", vec![]),
                ],
            ));

            let root = ir::Node::from_model(&m, None).unwrap();
            let mut opt = Optimizer::new(13);
            let new_root = opt.optimize(root).await.unwrap();
            let mut new_pairs = vec![];
            traverse(new_root, &mut new_pairs);
            assert_eq!(
                new_pairs,
                vec![
                    ("Neg_c".to_string(), "<outputs>".to_string()),
                    ("X".to_string(), "Neg_c".to_string())
                ]
            );
        });
    }

    // Tests that a chain of four Negs cancels out completely
    #[test]
    pub fn test_optimize_4neg() {
        let _ = env_logger::builder().is_test(true).try_init();
        pollster::block_on(async {
            let m = model(graph(
                vec![tensor("X", &[1])],
                vec![tensor("Y", &[1])],
                vec![tensor("A", &[1]), tensor("B", &[1]), tensor("C", &[1])],
                vec![],
                vec![
                    node(vec!["X"], vec!["A"], "a", "Neg", vec![]),
                    node(vec!["A"], vec!["B"], "b", "Neg", vec![]),
                    node(vec!["B"], vec!["C"], "c", "Neg", vec![]),
                    node(vec!["C"], vec!["Y"], "d", "Neg", vec![]),
                ],
            ));

            let root = ir::Node::from_model(&m, None).unwrap();
            let mut opt = Optimizer::new(13);
            let new_root = opt.optimize(root).await.unwrap();
            let mut new_pairs = vec![];
            traverse(new_root, &mut new_pairs);
            assert_eq!(new_pairs, vec![("X".to_string(), "<outputs>".to_string())]);
        });
    }

    // Tests that a chain of five Negs is reduced to a single Neg
    #[test]
    pub fn test_optimize_5neg() {
        let _ = env_logger::builder().is_test(true).try_init();
        pollster::block_on(async {
            let m = model(graph(
                vec![tensor("X", &[1])],
                vec![tensor("Y", &[1])],
                vec![
                    tensor("A", &[1]),
                    tensor("B", &[1]),
                    tensor("C", &[1]),
                    tensor("D", &[1]),
                ],
                vec![],
                vec![
                    node(vec!["X"], vec!["A"], "a", "Neg", vec![]),
                    node(vec!["A"], vec!["B"], "b", "Neg", vec![]),
                    node(vec!["B"], vec!["C"], "c", "Neg", vec![]),
                    node(vec!["C"], vec!["D"], "d", "Neg", vec![]),
                    node(vec!["D"], vec!["Y"], "e", "Neg", vec![]),
                ],
            ));

            let root = ir::Node::from_model(&m, None).unwrap();
            let mut opt = Optimizer::new(13);
            let new_root = opt.optimize(root).await.unwrap();
            let mut new_pairs = vec![];
            traverse(new_root, &mut new_pairs);
            assert_eq!(
                new_pairs,
                vec![
                    ("Neg_e".to_string(), "<outputs>".to_string()),
                    ("X".to_string(), "Neg_e".to_string())
                ]
            );
        });
    }

    // Tests double-Neg removal when the intermediate value A is itself a graph
    // output: Y folds to X, while A keeps its Neg
    #[test]
    pub fn test_optimize_neg_neg_branch() {
        let _ = env_logger::builder().is_test(true).try_init();
        pollster::block_on(async {
            let m = model(graph(
                vec![tensor("X", &[1])],
                vec![tensor("Y", &[1]), tensor("A", &[1])],
                vec![tensor("A", &[1])],
                vec![],
                vec![
                    node(vec!["X"], vec!["A"], "a", "Neg", vec![]),
                    node(vec!["A"], vec!["Y"], "b", "Neg", vec![]),
                ],
            ));

            let root = ir::Node::from_model(&m, None).unwrap();
            let mut opt = Optimizer::new(13);
            let new_root = opt.optimize(root).await.unwrap();
            let mut new_pairs = vec![];
            traverse(new_root, &mut new_pairs);
            assert_eq!(
                new_pairs,
                vec![
                    ("X".to_string(), "<outputs>".to_string()),
                    ("Neg_a".to_string(), "<outputs>".to_string()),
                    ("X".to_string(), "Neg_a".to_string())
                ]
            );
        });
    }

    // Tests Identity removal when a single node (Neg) feeds two graph outputs
    #[test]
    pub fn test_optimize_identity_identity_two_outputs() {
        let _ = env_logger::builder().is_test(true).try_init();

        pollster::block_on(async {
            let m = model(graph(
                vec![tensor("X", &[1])],
                vec![tensor("Y", &[1]), tensor("Z", &[1])],
                vec![tensor("A", &[1])],
                vec![],
                vec![
                    node(vec!["X"], vec!["A"], "a", "Neg", vec![]),
                    node(vec!["A"], vec!["Z"], "b", "Identity", vec![]),
                    node(vec!["A"], vec!["Y"], "c", "Identity", vec![]),
                ],
            ));

            let root = ir::Node::from_model(&m, None).unwrap();
            let mut opt = Optimizer::new(13);
            let new_root = opt.optimize(root).await.unwrap();
            let mut new_pairs = vec![];
            traverse(new_root, &mut new_pairs);
            assert_eq!(
                new_pairs,
                vec![
                    ("Neg_a".to_string(), "<outputs>".to_string()),
                    ("Neg_a".to_string(), "<outputs>".to_string()),
                    ("X".to_string(), "Neg_a".to_string()),
                    ("X".to_string(), "Neg_a".to_string()),
                ]
            );
        });
    }

    // Tests that an Add of two initializers is folded to a constant tensor C
    #[test]
    pub fn test_constant_folding() {
        let _ = env_logger::builder().is_test(true).try_init();

        pollster::block_on(async {
            let m = model(graph(
                vec![],
                vec![tensor("C", &[1])],
                vec![],
                vec![
                    initializer("A", vec![21.0], vec![1]),
                    initializer("B", vec![7.0], vec![1]),
                ],
                vec![node(vec!["A", "B"], vec!["C"], "c", "Add", vec![])],
            ));

            let root = ir::Node::from_model(&m, None).unwrap();
            let mut opt = Optimizer::new(13);
            let new_root = opt.optimize(root).await.unwrap();
            let mut new_pairs = vec![];
            traverse(new_root, &mut new_pairs);
            assert_eq!(new_pairs, vec![("C".to_string(), "<outputs>".to_string())]);
        });
    }

    // Tests that a Constant node is replaced by an initializer tensor
    #[test]
    pub fn test_constant_node_to_tensor() {
        let _ = env_logger::builder().is_test(true).try_init();

        pollster::block_on(async {
            let m = model(graph(
                vec![],
                vec![tensor("Y", &[1])],
                vec![],
                vec![],
                vec![node(
                    vec![],
                    vec!["Y"],
                    "y",
                    "Constant",
                    vec![attribute("value_float", 42.0)],
                )],
            ));

            let root = ir::Node::from_model(&m, None).unwrap();
            let mut opt = Optimizer::new(13);
            let new_root = opt.optimize(root).await.unwrap();
            let mut new_pairs = vec![];
            traverse(new_root.clone(), &mut new_pairs);
            assert_eq!(new_pairs, vec![("Y".to_string(), "<outputs>".to_string())]);

            let y_node = new_root.inputs[0].source_node.clone();
            assert!(matches!(y_node.definition(), NodeDefinition::Tensor(_)));
        });
    }

    // Tests static evaluation of the Shape operator for various combinations
    // of the start/end attributes
    #[test]
    pub fn test_shape_operator() {
        test_shape_operator_with(
            &[1, 2, 3],
            vec![attribute("start", -3), attribute("end", -2)],
            &[1],
        );
        test_shape_operator_with(&[1, 2, 3], vec![], &[1, 2, 3]);
        test_shape_operator_with(&[3, 4, 5], vec![attribute("start", 0)], &[3, 4, 5]);
        test_shape_operator_with(&[3, 4, 5], vec![attribute("start", 1)], &[4, 5]);
        test_shape_operator_with(&[3, 4, 5], vec![attribute("start", -1)], &[5]);
        test_shape_operator_with(&[3, 4, 5], vec![attribute("end", 10)], &[3, 4, 5]);
        test_shape_operator_with(&[3, 4, 5], vec![attribute("end", 1)], &[3]);
        test_shape_operator_with(
            &[3, 4, 5],
            vec![attribute("start", 10), attribute("end", 10)],
            &[],
        );
    }

    // Builds a model containing a single Shape node and checks that it is
    // folded to an initializer with the expected contents
    pub fn test_shape_operator_with(
        input_shape: &[i64],
        attrs: Vec<AttributeProto>,
        expected: &[i64],
    ) {
        let _ = env_logger::builder().is_test(true).try_init();

        pollster::block_on(async {
            let m = model(graph(
                vec![tensor("X", input_shape)],
                vec![tensor("Y", &[expected.len() as i64])],
                vec![],
                vec![],
                vec![node(vec!["X"], vec!["Y"], "y", "Shape", attrs)],
            ));

            let root = ir::Node::from_model(&m, None).unwrap();
            let mut opt = Optimizer::new(13);
            let new_root = opt.optimize(root).await.unwrap();
            let mut new_pairs = vec![];
            traverse(new_root.clone(), &mut new_pairs);
            assert_eq!(new_pairs, vec![("".to_string(), "<outputs>".to_string())]);

            let y_node = new_root.inputs[0].source_node.clone();
            let NodeDefinition::Tensor(t) = y_node.definition() else {
                panic!("should be folded to an initializer");
            };
            assert_eq!(t.get_int64_data(), expected);
        });
    }

    // Tests static evaluation of the Size operator
    #[test]
    pub fn test_size_operator() {
        test_size_operator_with(&[1, 2, 3], &[6]);
        test_size_operator_with(&[1], &[1]);
        test_size_operator_with(&[], &[1]);
    }

    // Builds a model containing a single Size node and checks that it is
    // folded to an initializer with the expected contents
    pub fn test_size_operator_with(input_shape: &[i64], expected: &[i64]) {
        let _ = env_logger::builder().is_test(true).try_init();

        pollster::block_on(async {
            let m = model(graph(
                vec![tensor("X", input_shape)],
                vec![tensor("Y", &[expected.len() as i64])],
                vec![],
                vec![],
                vec![node(vec!["X"], vec!["Y"], "y", "Size", vec![])],
            ));

            let root = ir::Node::from_model(&m, None).unwrap();
            let mut opt = Optimizer::new(13);
            let new_root = opt.optimize(root).await.unwrap();
            let mut new_pairs = vec![];
            traverse(new_root.clone(), &mut new_pairs);
            assert_eq!(new_pairs, vec![("".to_string(), "<outputs>".to_string())]);

            let y_node = new_root.inputs[0].source_node.clone();
            let NodeDefinition::Tensor(t) = y_node.definition() else {
                panic!("should be folded to an initializer");
            };
            assert_eq!(t.get_int64_data(), expected);
        });
    }
}