1use crate::ast::Node;
4use crate::onnx::convert::{sanitize_identifier, OnnxError};
5use crate::onnx::ops::{ConversionContext, ConversionResult, OpHandler};
6use crate::protos::onnx::NodeProto;
7use serde_json::Map;
8
/// Stateless [`OpHandler`] for the ONNX binary element-wise operators
/// `Add`, `Sub`, `Mul`, `Div`, `Pow`, `Min` and `Max`.
///
/// The handler carries no data, so it is cheap to copy and can be created
/// with [`Default::default`] or by naming the unit value directly.
#[derive(Debug, Clone, Copy, Default)]
pub struct ElementwiseHandler;
10
11impl OpHandler for ElementwiseHandler {
12 fn supports(&self, op_type: &str) -> bool {
13 matches!(
14 op_type,
15 "Add" | "Sub" | "Mul" | "Div" | "Pow" | "Min" | "Max"
16 )
17 }
18
19 fn convert(
20 &self,
21 node: &NodeProto,
22 context: &ConversionContext,
23 ) -> Result<ConversionResult, OnnxError> {
24 let op_type = node.op_type.as_str();
25 let node_name = if !node.name.is_empty() {
26 node.name.as_str().to_string()
27 } else {
28 "unnamed".to_string()
29 };
30
31 let inputs = node.input.as_slice();
32 if inputs.len() != 2 {
33 return Err(OnnxError::InvalidShape(format!(
34 "{} expects 2 inputs, got {}",
35 op_type,
36 inputs.len()
37 )));
38 }
39
40 let output_name = if node.output.as_slice().is_empty() {
41 format!("{}_output", node_name)
42 } else {
43 sanitize_identifier(&node.output.as_slice()[0].to_string())
44 };
45
46 let input0 = context.resolve_input(&inputs[0]);
48 let input1 = context.resolve_input(&inputs[1]);
49
50 let webnn_op = match op_type {
52 "Add" => "add",
53 "Sub" => "sub",
54 "Mul" => "mul",
55 "Div" => "div",
56 "Pow" => "pow",
57 "Min" => "min",
58 "Max" => "max",
59 _ => {
60 return Err(OnnxError::UnsupportedOp {
61 op: op_type.to_string(),
62 node: node_name,
63 })
64 }
65 };
66
67 let mut result = ConversionResult::new(vec![Node {
68 id: output_name.clone(),
69 op: webnn_op.to_string(),
70 inputs: vec![input0, input1],
71 options: Map::new(),
72 outputs: None,
73 }]);
74
75 if let Some(output) = node.output.as_slice().first() {
76 result
77 .output_mappings
78 .insert(output.to_string(), output_name.clone());
79 }
80
81 Ok(result)
82 }
83}
84
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protos::onnx::NodeProto;
    use std::collections::HashMap;

    /// Builds an ONNX node named `test_<op_type lowercased>` with the given
    /// operator, inputs and outputs; every other field keeps its default.
    fn create_test_node(op_type: &str, inputs: Vec<&str>, outputs: Vec<&str>) -> NodeProto {
        NodeProto {
            op_type: op_type.to_string(),
            name: format!("test_{}", op_type.to_lowercase()),
            input: inputs.iter().map(|s| s.to_string()).collect(),
            output: outputs.iter().map(|s| s.to_string()).collect(),
            ..Default::default()
        }
    }

    /// Runs `ElementwiseHandler::convert` on `node` against a context whose
    /// lookup tables are all empty.
    fn convert_with_empty_context(node: &NodeProto) -> Result<ConversionResult, OnnxError> {
        let initializers = HashMap::new();
        let value_shapes = HashMap::new();
        let const_values = HashMap::new();
        let value_ids = HashMap::new();
        let value_types = HashMap::new();
        let context = ConversionContext {
            initializers: &initializers,
            value_shapes: &value_shapes,
            value_shape_dims: crate::onnx::ops::empty_value_shape_dims(),
            const_values: &const_values,
            value_ids: &value_ids,
            value_types: &value_types,
        };

        ElementwiseHandler.convert(node, &context)
    }

    /// Converts `c = onnx_op(a, b)` and checks the single emitted node.
    fn assert_binary_conversion(onnx_op: &str, expected_op: &str) {
        let node = create_test_node(onnx_op, vec!["a", "b"], vec!["c"]);
        let result = convert_with_empty_context(&node).unwrap();

        assert_eq!(result.nodes.len(), 1);
        assert_eq!(result.nodes[0].op, expected_op);
        assert_eq!(result.nodes[0].inputs, vec!["a", "b"]);
        assert_eq!(result.nodes[0].id, "c");
    }

    #[test]
    fn test_elementwise_handler_supports() {
        let handler = ElementwiseHandler;
        for op in ["Add", "Sub", "Mul", "Div", "Pow", "Min", "Max"] {
            assert!(handler.supports(op), "{} should be supported", op);
        }
        assert!(!handler.supports("MatMul"));
    }

    #[test]
    fn test_convert_add() {
        assert_binary_conversion("Add", "add");
    }

    #[test]
    fn test_convert_sub() {
        assert_binary_conversion("Sub", "sub");
    }

    #[test]
    fn test_convert_mul() {
        assert_binary_conversion("Mul", "mul");
    }

    #[test]
    fn test_convert_div() {
        assert_binary_conversion("Div", "div");
    }

    #[test]
    fn test_convert_pow() {
        assert_binary_conversion("Pow", "pow");
    }

    #[test]
    fn test_convert_min() {
        assert_binary_conversion("Min", "min");
    }

    #[test]
    fn test_convert_max() {
        assert_binary_conversion("Max", "max");
    }

    #[test]
    fn test_convert_rejects_wrong_input_count() {
        // Anything other than exactly two inputs is refused before conversion.
        for inputs in [vec![], vec!["a"], vec!["a", "b", "c"]] {
            let node = create_test_node("Add", inputs, vec!["out"]);
            let result = convert_with_empty_context(&node);
            assert!(matches!(result, Err(OnnxError::InvalidShape(_))));
        }
    }

    #[test]
    fn test_convert_rejects_unsupported_op() {
        let node = create_test_node("MatMul", vec!["a", "b"], vec!["c"]);
        match convert_with_empty_context(&node) {
            Err(OnnxError::UnsupportedOp { op, node }) => {
                assert_eq!(op, "MatMul");
                assert_eq!(node, "test_matmul");
            }
            _ => panic!("expected OnnxError::UnsupportedOp"),
        }
    }
}