1use crate::ast::Node;
4use crate::onnx::convert::{sanitize_identifier, OnnxError};
5use crate::onnx::ops::{
6 normalize_axes_best_effort, ConversionContext, ConversionResult, OpHandler,
7};
8use crate::protos::onnx::NodeProto;
9use serde_json::Map;
10
/// Converts the ONNX reduction ops `ReduceMean`, `ReduceSum`, `ReduceMax` and
/// `ReduceMin` into the WebNN `reduceMean` / `reduceSum` / `reduceMax` /
/// `reduceMin` ops respectively.
pub struct ReductionHandler;
12
13impl OpHandler for ReductionHandler {
14 fn supports(&self, op_type: &str) -> bool {
15 matches!(
16 op_type,
17 "ReduceMean" | "ReduceSum" | "ReduceMax" | "ReduceMin"
18 )
19 }
20
21 fn convert(
22 &self,
23 node: &NodeProto,
24 context: &ConversionContext,
25 ) -> Result<ConversionResult, OnnxError> {
26 let op_type = node.op_type.as_str();
27 let node_name = if !node.name.is_empty() {
28 node.name.as_str().to_string()
29 } else {
30 "unnamed".to_string()
31 };
32
33 match op_type {
34 "ReduceMean" => self.convert_reduce(node, &node_name, "reduceMean", context),
35 "ReduceSum" => self.convert_reduce(node, &node_name, "reduceSum", context),
36 "ReduceMax" => self.convert_reduce(node, &node_name, "reduceMax", context),
37 "ReduceMin" => self.convert_reduce(node, &node_name, "reduceMin", context),
38 _ => Err(OnnxError::UnsupportedOp {
39 op: op_type.to_string(),
40 node: node_name,
41 }),
42 }
43 }
44}
45
46impl ReductionHandler {
47 fn convert_reduce(
49 &self,
50 node: &NodeProto,
51 node_name: &str,
52 webnn_op: &str,
53 context: &ConversionContext,
54 ) -> Result<ConversionResult, OnnxError> {
55 let inputs = node.input.as_slice();
56 if inputs.is_empty() {
57 return Err(OnnxError::InvalidShape(format!(
58 "{} expects at least 1 input",
59 webnn_op
60 )));
61 }
62
63 let mut axes: Option<Vec<i64>> = None;
65 let mut keepdims = 1i64; for attr in node.attribute.as_slice() {
68 match attr.name.as_str() {
69 "axes" => {
70 axes = Some(attr.ints.clone());
71 }
72 "keepdims" if attr.i != 0 => {
73 keepdims = attr.i;
74 }
75 _ => {}
76 }
77 }
78
79 let output_name = if node.output.as_slice().is_empty() {
80 format!("{}_output", node_name)
81 } else {
82 sanitize_identifier(&node.output.as_slice()[0].to_string())
83 };
84
85 let input0 = context.resolve_input(&inputs[0]);
86
87 let mut options = Map::new();
88
89 if let Some(axes_values) = axes {
91 let axes_values = if let Some(rank) = context.input_rank(inputs[0].as_str()) {
92 normalize_axes_best_effort(&axes_values, rank)
93 } else {
94 axes_values
95 };
96 options.insert("axes".to_string(), serde_json::json!(axes_values));
97 }
98
99 options.insert(
101 "keepDimensions".to_string(),
102 serde_json::json!(keepdims != 0),
103 );
104
105 let mut result = ConversionResult::new(vec![Node {
106 id: output_name.clone(),
107 op: webnn_op.to_string(),
108 inputs: vec![input0],
109 options,
110 outputs: None,
111 }]);
112
113 if let Some(output) = node.output.as_slice().first() {
114 result
115 .output_mappings
116 .insert(output.to_string(), output_name.clone());
117 }
118
119 Ok(result)
120 }
121}
122
123#[cfg(test)]
124mod tests {
125 use super::*;
126 use crate::protos::onnx::{AttributeProto, NodeProto};
127
128 fn create_test_node(op_type: &str, inputs: Vec<&str>, outputs: Vec<&str>) -> NodeProto {
129 NodeProto {
130 op_type: op_type.to_string(),
131 name: format!("test_{}", op_type.to_lowercase()),
132 input: inputs.iter().map(|s| s.to_string()).collect(),
133 output: outputs.iter().map(|s| s.to_string()).collect(),
134 ..Default::default()
135 }
136 }
137
138 fn add_int_attribute(node: &mut NodeProto, name: &str, value: i64) {
139 let attr = AttributeProto {
140 name: name.to_string(),
141 i: value,
142 ..Default::default()
143 };
144 node.attribute.push(attr);
145 }
146
147 fn add_ints_attribute(node: &mut NodeProto, name: &str, values: Vec<i64>) {
148 let attr = AttributeProto {
149 name: name.to_string(),
150 ints: values,
151 ..Default::default()
152 };
153 node.attribute.push(attr);
154 }
155
156 #[test]
157 fn test_reduction_handler_supports() {
158 let handler = ReductionHandler;
159 assert!(handler.supports("ReduceMean"));
160 assert!(handler.supports("ReduceSum"));
161 assert!(handler.supports("ReduceMax"));
162 assert!(handler.supports("ReduceMin"));
163 assert!(!handler.supports("Add"));
164 }
165
166 #[test]
167 fn test_convert_reduce_mean() {
168 let handler = ReductionHandler;
169 let mut node = create_test_node("ReduceMean", vec!["x"], vec!["y"]);
170 add_ints_attribute(&mut node, "axes", vec![1, 2]);
171 add_int_attribute(&mut node, "keepdims", 1);
172 let initializers = std::collections::HashMap::new();
173 let value_shapes = std::collections::HashMap::new();
174 let const_values = std::collections::HashMap::new();
175 let value_ids = std::collections::HashMap::new();
176 let value_types = std::collections::HashMap::new();
177 let context = ConversionContext {
178 initializers: &initializers,
179 value_shapes: &value_shapes,
180 value_shape_dims: crate::onnx::ops::empty_value_shape_dims(),
181 const_values: &const_values,
182 value_ids: &value_ids,
183 value_types: &value_types,
184 };
185
186 let result = handler.convert(&node, &context).unwrap();
187 assert_eq!(result.nodes.len(), 1);
188 assert_eq!(result.nodes[0].op, "reduceMean");
189 assert_eq!(result.nodes[0].inputs, vec!["x"]);
190 assert!(result.nodes[0].options.contains_key("axes"));
191 assert!(result.nodes[0].options.contains_key("keepDimensions"));
192 }
193
194 #[test]
195 fn test_convert_reduce_sum() {
196 let handler = ReductionHandler;
197 let mut node = create_test_node("ReduceSum", vec!["x"], vec!["y"]);
198 add_ints_attribute(&mut node, "axes", vec![-1]);
199 let initializers = std::collections::HashMap::new();
200 let mut value_shapes = std::collections::HashMap::new();
201 value_shapes.insert("x".to_string(), vec![2, 3, 4]);
202 let const_values = std::collections::HashMap::new();
203 let value_ids = std::collections::HashMap::new();
204 let value_types = std::collections::HashMap::new();
205 let context = ConversionContext {
206 initializers: &initializers,
207 value_shapes: &value_shapes,
208 value_shape_dims: crate::onnx::ops::empty_value_shape_dims(),
209 const_values: &const_values,
210 value_ids: &value_ids,
211 value_types: &value_types,
212 };
213
214 let result = handler.convert(&node, &context).unwrap();
215 assert_eq!(result.nodes.len(), 1);
216 assert_eq!(result.nodes[0].op, "reduceSum");
217 assert_eq!(
218 result.nodes[0].options.get("axes"),
219 Some(&serde_json::json!([2]))
220 );
221 }
222}