Skip to main content

webnn_graph/onnx/ops/
reduction.rs

// Reduction operators: ReduceMean, ReduceSum, ReduceMax, ReduceMin
2
3use crate::ast::Node;
4use crate::onnx::convert::{sanitize_identifier, OnnxError};
5use crate::onnx::ops::{
6    normalize_axes_best_effort, ConversionContext, ConversionResult, OpHandler,
7};
8use crate::protos::onnx::NodeProto;
9use serde_json::Map;
10
/// Converts the ONNX reductions `ReduceMean`, `ReduceSum`, `ReduceMax` and `ReduceMin`
/// into the corresponding WebNN `reduce*` operations.
pub struct ReductionHandler;
12
13impl OpHandler for ReductionHandler {
14    fn supports(&self, op_type: &str) -> bool {
15        matches!(
16            op_type,
17            "ReduceMean" | "ReduceSum" | "ReduceMax" | "ReduceMin"
18        )
19    }
20
21    fn convert(
22        &self,
23        node: &NodeProto,
24        context: &ConversionContext,
25    ) -> Result<ConversionResult, OnnxError> {
26        let op_type = node.op_type.as_str();
27        let node_name = if !node.name.is_empty() {
28            node.name.as_str().to_string()
29        } else {
30            "unnamed".to_string()
31        };
32
33        match op_type {
34            "ReduceMean" => self.convert_reduce(node, &node_name, "reduceMean", context),
35            "ReduceSum" => self.convert_reduce(node, &node_name, "reduceSum", context),
36            "ReduceMax" => self.convert_reduce(node, &node_name, "reduceMax", context),
37            "ReduceMin" => self.convert_reduce(node, &node_name, "reduceMin", context),
38            _ => Err(OnnxError::UnsupportedOp {
39                op: op_type.to_string(),
40                node: node_name,
41            }),
42        }
43    }
44}
45
46impl ReductionHandler {
47    /// Convert ONNX reduce operations to WebNN reduce operations
48    fn convert_reduce(
49        &self,
50        node: &NodeProto,
51        node_name: &str,
52        webnn_op: &str,
53        context: &ConversionContext,
54    ) -> Result<ConversionResult, OnnxError> {
55        let inputs = node.input.as_slice();
56        if inputs.is_empty() {
57            return Err(OnnxError::InvalidShape(format!(
58                "{} expects at least 1 input",
59                webnn_op
60            )));
61        }
62
63        // Extract attributes
64        let mut axes: Option<Vec<i64>> = None;
65        let mut keepdims = 1i64; // ONNX default is 1 (keep dimensions)
66
67        for attr in node.attribute.as_slice() {
68            match attr.name.as_str() {
69                "axes" => {
70                    axes = Some(attr.ints.clone());
71                }
72                "keepdims" if attr.i != 0 => {
73                    keepdims = attr.i;
74                }
75                _ => {}
76            }
77        }
78
79        let output_name = if node.output.as_slice().is_empty() {
80            format!("{}_output", node_name)
81        } else {
82            sanitize_identifier(&node.output.as_slice()[0].to_string())
83        };
84
85        let input0 = context.resolve_input(&inputs[0]);
86
87        let mut options = Map::new();
88
89        // Add axes if specified
90        if let Some(axes_values) = axes {
91            let axes_values = if let Some(rank) = context.input_rank(inputs[0].as_str()) {
92                normalize_axes_best_effort(&axes_values, rank)
93            } else {
94                axes_values
95            };
96            options.insert("axes".to_string(), serde_json::json!(axes_values));
97        }
98
99        // Add keepDims option (WebNN uses keepDimensions)
100        options.insert(
101            "keepDimensions".to_string(),
102            serde_json::json!(keepdims != 0),
103        );
104
105        let mut result = ConversionResult::new(vec![Node {
106            id: output_name.clone(),
107            op: webnn_op.to_string(),
108            inputs: vec![input0],
109            options,
110            outputs: None,
111        }]);
112
113        if let Some(output) = node.output.as_slice().first() {
114            result
115                .output_mappings
116                .insert(output.to_string(), output_name.clone());
117        }
118
119        Ok(result)
120    }
121}
122
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protos::onnx::{AttributeProto, NodeProto};
    use std::collections::HashMap;

    /// Declares `let $ctx = ConversionContext { .. }` in the calling test.
    ///
    /// The backing maps are declared by the macro too, so they live until the end of the
    /// calling scope. `$shapes` becomes the `value_shapes` map; every other map is empty.
    macro_rules! conversion_context {
        ($ctx:ident, $shapes:expr) => {
            let initializers = HashMap::new();
            let value_shapes = $shapes;
            let const_values = HashMap::new();
            let value_ids = HashMap::new();
            let value_types = HashMap::new();
            let $ctx = ConversionContext {
                initializers: &initializers,
                value_shapes: &value_shapes,
                value_shape_dims: crate::onnx::ops::empty_value_shape_dims(),
                const_values: &const_values,
                value_ids: &value_ids,
                value_types: &value_types,
            };
        };
    }

    fn create_test_node(op_type: &str, inputs: Vec<&str>, outputs: Vec<&str>) -> NodeProto {
        NodeProto {
            op_type: op_type.to_string(),
            name: format!("test_{}", op_type.to_lowercase()),
            input: inputs.iter().map(|s| s.to_string()).collect(),
            output: outputs.iter().map(|s| s.to_string()).collect(),
            ..Default::default()
        }
    }

    fn add_int_attribute(node: &mut NodeProto, name: &str, value: i64) {
        node.attribute.push(AttributeProto {
            name: name.to_string(),
            i: value,
            ..Default::default()
        });
    }

    fn add_ints_attribute(node: &mut NodeProto, name: &str, values: Vec<i64>) {
        node.attribute.push(AttributeProto {
            name: name.to_string(),
            ints: values,
            ..Default::default()
        });
    }

    #[test]
    fn test_reduction_handler_supports() {
        let handler = ReductionHandler;
        assert!(handler.supports("ReduceMean"));
        assert!(handler.supports("ReduceSum"));
        assert!(handler.supports("ReduceMax"));
        assert!(handler.supports("ReduceMin"));
        assert!(!handler.supports("Add"));
    }

    #[test]
    fn test_convert_reduce_mean() {
        let handler = ReductionHandler;
        let mut node = create_test_node("ReduceMean", vec!["x"], vec!["y"]);
        add_ints_attribute(&mut node, "axes", vec![1, 2]);
        add_int_attribute(&mut node, "keepdims", 1);
        conversion_context!(context, HashMap::new());

        let result = handler.convert(&node, &context).unwrap();
        assert_eq!(result.nodes.len(), 1);
        let reduce = &result.nodes[0];
        assert_eq!(reduce.op, "reduceMean");
        assert_eq!(reduce.inputs, vec!["x"]);
        // No shape is registered for `x`; positive axes come through as given.
        assert_eq!(reduce.options.get("axes"), Some(&serde_json::json!([1, 2])));
        assert_eq!(
            reduce.options.get("keepDimensions"),
            Some(&serde_json::json!(true))
        );
        assert!(result.output_mappings.contains_key("y"));
    }

    #[test]
    fn test_convert_reduce_sum() {
        let handler = ReductionHandler;
        let mut node = create_test_node("ReduceSum", vec!["x"], vec!["y"]);
        add_ints_attribute(&mut node, "axes", vec![-1]);
        conversion_context!(
            context,
            HashMap::from([("x".to_string(), vec![2, 3, 4])])
        );

        let result = handler.convert(&node, &context).unwrap();
        assert_eq!(result.nodes.len(), 1);
        assert_eq!(result.nodes[0].op, "reduceSum");
        // Rank 3 is known for `x`, so axis -1 is normalized to 2.
        assert_eq!(
            result.nodes[0].options.get("axes"),
            Some(&serde_json::json!([2]))
        );
    }

    #[test]
    fn test_convert_reduce_defaults_without_attributes() {
        let handler = ReductionHandler;
        let node = create_test_node("ReduceMax", vec!["x"], vec!["y"]);
        conversion_context!(context, HashMap::new());

        let result = handler.convert(&node, &context).unwrap();
        assert_eq!(result.nodes.len(), 1);
        let reduce = &result.nodes[0];
        assert_eq!(reduce.op, "reduceMax");
        // Without an `axes` attribute the option is omitted entirely.
        assert!(!reduce.options.contains_key("axes"));
        // ONNX keepdims defaults to 1.
        assert_eq!(
            reduce.options.get("keepDimensions"),
            Some(&serde_json::json!(true))
        );
    }

    #[test]
    fn test_convert_rejects_unsupported_op() {
        let handler = ReductionHandler;
        let node = create_test_node("Add", vec!["a", "b"], vec!["c"]);
        conversion_context!(context, HashMap::new());

        let result = handler.convert(&node, &context);
        assert!(matches!(result, Err(OnnxError::UnsupportedOp { .. })));
    }

    #[test]
    fn test_convert_requires_an_input() {
        let handler = ReductionHandler;
        let node = create_test_node("ReduceMin", vec![], vec!["y"]);
        conversion_context!(context, HashMap::new());

        let result = handler.convert(&node, &context);
        assert!(matches!(result, Err(OnnxError::InvalidShape(_))));
    }
}