1use crate::ast::Node;
4use crate::onnx::convert::{sanitize_identifier, OnnxError};
5use crate::onnx::ops::{
6 normalize_axis_best_effort, ConversionContext, ConversionResult, OpHandler,
7};
8use crate::protos::onnx::NodeProto;
9use serde_json::Map;
10
11pub struct NormalizationHandler;
12
13impl OpHandler for NormalizationHandler {
14 fn supports(&self, op_type: &str) -> bool {
15 matches!(op_type, "LayerNormalization" | "Softmax")
16 }
17
18 fn convert(
19 &self,
20 node: &NodeProto,
21 context: &ConversionContext,
22 ) -> Result<ConversionResult, OnnxError> {
23 let op_type = node.op_type.as_str();
24 let node_name = if !node.name.is_empty() {
25 node.name.as_str().to_string()
26 } else {
27 "unnamed".to_string()
28 };
29
30 match op_type {
31 "LayerNormalization" => self.convert_layer_norm(node, &node_name, context),
32 "Softmax" => self.convert_softmax(node, &node_name, context),
33 _ => Err(OnnxError::UnsupportedOp {
34 op: op_type.to_string(),
35 node: node_name,
36 }),
37 }
38 }
39}
40
41impl NormalizationHandler {
42 fn convert_layer_norm(
44 &self,
45 node: &NodeProto,
46 node_name: &str,
47 context: &ConversionContext,
48 ) -> Result<ConversionResult, OnnxError> {
49 let inputs = node.input.as_slice();
50 if inputs.is_empty() {
51 return Err(OnnxError::InvalidShape(
52 "LayerNormalization expects at least 1 input".to_string(),
53 ));
54 }
55
56 let mut epsilon = 1e-5f32;
58 let mut axis = -1i64;
59
60 for attr in node.attribute.as_slice() {
61 match attr.name.as_str() {
62 "epsilon" if attr.f != 0.0 => {
63 epsilon = attr.f;
64 }
65 "axis" if attr.i != 0 => {
66 axis = attr.i;
67 }
68 _ => {}
69 }
70 }
71
72 let output_name = if node.output.as_slice().is_empty() {
73 format!("{}_output", node_name)
74 } else {
75 sanitize_identifier(&node.output.as_slice()[0].to_string())
76 };
77
78 let mut options = Map::new();
79 options.insert("epsilon".to_string(), serde_json::json!(epsilon));
80
81 if let Some(rank) = context.input_rank(inputs[0].as_str()) {
83 let normalized_axis = normalize_axis_best_effort(axis, rank);
84 options.insert("axes".to_string(), serde_json::json!([normalized_axis]));
85 } else if axis != -1 {
86 options.insert("axes".to_string(), serde_json::json!([axis]));
87 }
88
89 let webnn_inputs = if inputs.len() >= 3 {
91 let input0 = context.resolve_input(&inputs[0]);
93 let input1 = context.resolve_input(&inputs[1]);
94 let input2 = context.resolve_input(&inputs[2]);
95 vec![input0, input1, input2]
96 } else if inputs.len() == 2 {
97 let input0 = context.resolve_input(&inputs[0]);
99 let input1 = context.resolve_input(&inputs[1]);
100 vec![input0, input1]
101 } else {
102 let input0 = context.resolve_input(&inputs[0]);
104 vec![input0]
105 };
106
107 let mut result = ConversionResult::new(vec![Node {
108 id: output_name.clone(),
109 op: "layerNormalization".to_string(),
110 inputs: webnn_inputs,
111 options,
112 outputs: None,
113 }]);
114
115 if let Some(output) = node.output.as_slice().first() {
116 result
117 .output_mappings
118 .insert(output.to_string(), output_name.clone());
119 }
120
121 Ok(result)
122 }
123
124 fn convert_softmax(
126 &self,
127 node: &NodeProto,
128 node_name: &str,
129 context: &ConversionContext,
130 ) -> Result<ConversionResult, OnnxError> {
131 let inputs = node.input.as_slice();
132 if inputs.len() != 1 {
133 return Err(OnnxError::InvalidShape(format!(
134 "Softmax expects 1 input, got {}",
135 inputs.len()
136 )));
137 }
138
139 let mut axis = -1i64;
141 for attr in node.attribute.as_slice() {
142 if attr.name.as_str() == "axis" && attr.i != 0 {
143 axis = attr.i;
144 }
145 }
146
147 let output_name = if node.output.as_slice().is_empty() {
148 format!("{}_output", node_name)
149 } else {
150 sanitize_identifier(&node.output.as_slice()[0].to_string())
151 };
152
153 let input0 = context.resolve_input(&inputs[0]);
154
155 let axis = if let Some(rank) = context.input_rank(inputs[0].as_str()) {
156 normalize_axis_best_effort(axis, rank)
157 } else {
158 axis
159 };
160
161 let mut options = Map::new();
162 options.insert("axis".to_string(), serde_json::json!(axis));
163
164 let mut result = ConversionResult::new(vec![Node {
165 id: output_name.clone(),
166 op: "softmax".to_string(),
167 inputs: vec![input0],
168 options,
169 outputs: None,
170 }]);
171
172 if let Some(output) = node.output.as_slice().first() {
173 result
174 .output_mappings
175 .insert(output.to_string(), output_name.clone());
176 }
177
178 Ok(result)
179 }
180}
181
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protos::onnx::{AttributeProto, NodeProto};
    use std::collections::HashMap;

    /// Builds a node of `op_type` named `test_<lowercased op_type>` with the
    /// given input and output names.
    fn create_test_node(op_type: &str, inputs: Vec<&str>, outputs: Vec<&str>) -> NodeProto {
        NodeProto {
            name: format!("test_{}", op_type.to_lowercase()),
            op_type: op_type.to_string(),
            input: inputs.into_iter().map(String::from).collect(),
            output: outputs.into_iter().map(String::from).collect(),
            ..Default::default()
        }
    }

    /// Appends an attribute called `name` whose integer payload is `value`.
    fn add_int_attribute(node: &mut NodeProto, name: &str, value: i64) {
        node.attribute.push(AttributeProto {
            name: name.to_string(),
            i: value,
            ..Default::default()
        });
    }

    #[test]
    fn test_normalization_handler_supports() {
        let handler = NormalizationHandler;
        for supported in ["LayerNormalization", "Softmax"] {
            assert!(handler.supports(supported));
        }
        assert!(!handler.supports("Add"));
    }

    #[test]
    fn test_convert_softmax() {
        let mut node = create_test_node("Softmax", vec!["x"], vec!["y"]);
        add_int_attribute(&mut node, "axis", -1);

        // Only the shape of `x` is known; every other lookup table is empty.
        let initializers = HashMap::new();
        let value_shapes = HashMap::from([("x".to_string(), vec![1, 128, 384])]);
        let const_values = HashMap::new();
        let value_ids = HashMap::new();
        let value_types = HashMap::new();
        let context = ConversionContext {
            initializers: &initializers,
            value_shapes: &value_shapes,
            value_shape_dims: crate::onnx::ops::empty_value_shape_dims(),
            const_values: &const_values,
            value_ids: &value_ids,
            value_types: &value_types,
        };

        let converted = NormalizationHandler.convert(&node, &context).unwrap();
        assert_eq!(converted.nodes.len(), 1);

        let emitted = &converted.nodes[0];
        assert_eq!(emitted.op, "softmax");
        assert_eq!(emitted.inputs, vec!["x"]);
        assert_eq!(emitted.id, "y");
        assert!(emitted.options.contains_key("axis"));
        // axis -1 on a rank-3 input is expected to normalize to 2.
        assert_eq!(emitted.options.get("axis"), Some(&serde_json::json!(2)));
    }

    #[test]
    fn test_convert_layer_norm() {
        let mut node =
            create_test_node("LayerNormalization", vec!["x", "scale", "bias"], vec!["y"]);
        add_int_attribute(&mut node, "axis", -1);

        // Only the shape of `x` is known; every other lookup table is empty.
        let initializers = HashMap::new();
        let value_shapes = HashMap::from([("x".to_string(), vec![1, 128, 384])]);
        let const_values = HashMap::new();
        let value_ids = HashMap::new();
        let value_types = HashMap::new();
        let context = ConversionContext {
            initializers: &initializers,
            value_shapes: &value_shapes,
            value_shape_dims: crate::onnx::ops::empty_value_shape_dims(),
            const_values: &const_values,
            value_ids: &value_ids,
            value_types: &value_types,
        };

        let converted = NormalizationHandler.convert(&node, &context).unwrap();
        assert_eq!(converted.nodes.len(), 1);

        let emitted = &converted.nodes[0];
        assert_eq!(emitted.op, "layerNormalization");
        assert_eq!(emitted.inputs.len(), 3);
        assert!(emitted.options.contains_key("epsilon"));
        // axis -1 on a rank-3 input is expected to normalize to [2].
        assert_eq!(emitted.options.get("axes"), Some(&serde_json::json!([2])));
    }
}
277}