Skip to main content

webnn_graph/onnx/ops/
elementwise.rs

// Elementwise binary operators: Add, Sub, Mul, Div, Pow, Min, Max
2
3use crate::ast::Node;
4use crate::onnx::convert::{sanitize_identifier, OnnxError};
5use crate::onnx::ops::{ConversionContext, ConversionResult, OpHandler};
6use crate::protos::onnx::NodeProto;
7use serde_json::Map;
8
/// Converts ONNX elementwise arithmetic nodes (`Add`, `Sub`, `Mul`, `Div`, `Pow`, `Min`,
/// `Max`) into the corresponding WebNN graph operations.
///
/// The handler is stateless, so it is a zero-sized unit struct that is cheap to copy.
#[derive(Debug, Clone, Copy, Default)]
pub struct ElementwiseHandler;
10
11impl OpHandler for ElementwiseHandler {
12    fn supports(&self, op_type: &str) -> bool {
13        matches!(
14            op_type,
15            "Add" | "Sub" | "Mul" | "Div" | "Pow" | "Min" | "Max"
16        )
17    }
18
19    fn convert(
20        &self,
21        node: &NodeProto,
22        context: &ConversionContext,
23    ) -> Result<ConversionResult, OnnxError> {
24        let op_type = node.op_type.as_str();
25        let node_name = if !node.name.is_empty() {
26            node.name.as_str().to_string()
27        } else {
28            "unnamed".to_string()
29        };
30
31        let inputs = node.input.as_slice();
32        if inputs.len() != 2 {
33            return Err(OnnxError::InvalidShape(format!(
34                "{} expects 2 inputs, got {}",
35                op_type,
36                inputs.len()
37            )));
38        }
39
40        let output_name = if node.output.as_slice().is_empty() {
41            format!("{}_output", node_name)
42        } else {
43            sanitize_identifier(&node.output.as_slice()[0].to_string())
44        };
45
46        // Resolve input names (respecting prior mappings)
47        let input0 = context.resolve_input(&inputs[0]);
48        let input1 = context.resolve_input(&inputs[1]);
49
50        // Map ONNX operator to WebNN operator (lowercase)
51        let webnn_op = match op_type {
52            "Add" => "add",
53            "Sub" => "sub",
54            "Mul" => "mul",
55            "Div" => "div",
56            "Pow" => "pow",
57            "Min" => "min",
58            "Max" => "max",
59            _ => {
60                return Err(OnnxError::UnsupportedOp {
61                    op: op_type.to_string(),
62                    node: node_name,
63                })
64            }
65        };
66
67        let mut result = ConversionResult::new(vec![Node {
68            id: output_name.clone(),
69            op: webnn_op.to_string(),
70            inputs: vec![input0, input1],
71            options: Map::new(),
72            outputs: None,
73        }]);
74
75        if let Some(output) = node.output.as_slice().first() {
76            result
77                .output_mappings
78                .insert(output.to_string(), output_name.clone());
79        }
80
81        Ok(result)
82    }
83}
84
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protos::onnx::NodeProto;
    use std::collections::HashMap;

    /// Builds a `NodeProto` named `test_<op>` with the given input and output tensor names.
    fn create_test_node(op_type: &str, inputs: Vec<&str>, outputs: Vec<&str>) -> NodeProto {
        NodeProto {
            op_type: op_type.to_string(),
            name: format!("test_{}", op_type.to_lowercase()),
            input: inputs.iter().map(|s| s.to_string()).collect(),
            output: outputs.iter().map(|s| s.to_string()).collect(),
            ..Default::default()
        }
    }

    /// Runs `ElementwiseHandler::convert` on a freshly built node.
    ///
    /// The conversion context is empty: no initializers, shapes, constants, ids or types.
    fn run_convert(
        op_type: &str,
        inputs: Vec<&str>,
        outputs: Vec<&str>,
    ) -> Result<ConversionResult, OnnxError> {
        let node = create_test_node(op_type, inputs, outputs);
        let initializers = HashMap::new();
        let value_shapes = HashMap::new();
        let const_values = HashMap::new();
        let value_ids = HashMap::new();
        let value_types = HashMap::new();
        let context = ConversionContext {
            initializers: &initializers,
            value_shapes: &value_shapes,
            value_shape_dims: crate::onnx::ops::empty_value_shape_dims(),
            const_values: &const_values,
            value_ids: &value_ids,
            value_types: &value_types,
        };
        ElementwiseHandler.convert(&node, &context)
    }

    #[test]
    fn test_elementwise_handler_supports() {
        let handler = ElementwiseHandler;
        assert!(handler.supports("Add"));
        assert!(handler.supports("Sub"));
        assert!(handler.supports("Mul"));
        assert!(handler.supports("Div"));
        assert!(handler.supports("Pow"));
        assert!(handler.supports("Min"));
        assert!(handler.supports("Max"));
        assert!(!handler.supports("MatMul"));
        // ONNX op types are case-sensitive.
        assert!(!handler.supports("add"));
        assert!(!handler.supports(""));
    }

    #[test]
    fn test_convert_maps_every_op() {
        let cases = [
            ("Add", "add"),
            ("Sub", "sub"),
            ("Mul", "mul"),
            ("Div", "div"),
            ("Pow", "pow"),
            ("Min", "min"),
            ("Max", "max"),
        ];
        for &(onnx_op, webnn_op) in cases.iter() {
            let result = run_convert(onnx_op, vec!["a", "b"], vec!["c"]).unwrap();
            assert_eq!(result.nodes.len(), 1, "{}", onnx_op);
            assert_eq!(result.nodes[0].op, webnn_op, "{}", onnx_op);
            assert_eq!(result.nodes[0].inputs, vec!["a", "b"], "{}", onnx_op);
        }
    }

    #[test]
    fn test_convert_add() {
        let result = run_convert("Add", vec!["a", "b"], vec!["c"]).unwrap();
        assert_eq!(result.nodes.len(), 1);
        assert_eq!(result.nodes[0].op, "add");
        assert_eq!(result.nodes[0].inputs, vec!["a", "b"]);
        assert_eq!(result.nodes[0].id, "c");
    }

    #[test]
    fn test_convert_mul() {
        let result = run_convert("Mul", vec!["x", "y"], vec!["z"]).unwrap();
        assert_eq!(result.nodes.len(), 1);
        assert_eq!(result.nodes[0].op, "mul");
        assert_eq!(result.nodes[0].inputs, vec!["x", "y"]);
    }

    #[test]
    fn test_convert_div() {
        let result = run_convert("Div", vec!["a", "b"], vec!["c"]).unwrap();
        assert_eq!(result.nodes.len(), 1);
        assert_eq!(result.nodes[0].op, "div");
    }

    #[test]
    fn test_convert_min() {
        let result = run_convert("Min", vec!["x", "y"], vec!["z"]).unwrap();
        assert_eq!(result.nodes.len(), 1);
        assert_eq!(result.nodes[0].op, "min");
        assert_eq!(result.nodes[0].inputs, vec!["x", "y"]);
    }

    #[test]
    fn test_convert_max() {
        let result = run_convert("Max", vec!["a", "b"], vec!["c"]).unwrap();
        assert_eq!(result.nodes.len(), 1);
        assert_eq!(result.nodes[0].op, "max");
        assert_eq!(result.nodes[0].inputs, vec!["a", "b"]);
    }

    #[test]
    fn test_convert_records_output_mapping() {
        let result = run_convert("Sub", vec!["a", "b"], vec!["c"]).unwrap();
        assert_eq!(result.nodes[0].op, "sub");
        assert_eq!(
            result.output_mappings.get("c").map(String::as_str),
            Some("c")
        );
    }

    #[test]
    fn test_convert_without_outputs_has_no_mapping() {
        let result = run_convert("Add", vec!["a", "b"], vec![]).unwrap();
        assert_eq!(result.nodes.len(), 1);
        assert!(result.output_mappings.is_empty());
    }

    #[test]
    fn test_convert_rejects_wrong_input_count() {
        assert!(matches!(
            run_convert("Add", vec!["a"], vec!["c"]),
            Err(OnnxError::InvalidShape(_))
        ));
        assert!(matches!(
            run_convert("Div", vec!["a", "b", "d"], vec!["c"]),
            Err(OnnxError::InvalidShape(_))
        ));
    }

    #[test]
    fn test_convert_rejects_unsupported_op() {
        assert!(matches!(
            run_convert("MatMul", vec!["a", "b"], vec!["c"]),
            Err(OnnxError::UnsupportedOp { .. })
        ));
    }
}