Skip to main content

webnn_graph/onnx/ops/
normalization.rs

// Normalization operators: LayerNormalization, Softmax
2
3use crate::ast::Node;
4use crate::onnx::convert::{sanitize_identifier, OnnxError};
5use crate::onnx::ops::{
6    normalize_axis_best_effort, ConversionContext, ConversionResult, OpHandler,
7};
8use crate::protos::onnx::NodeProto;
9use serde_json::Map;
10
/// Stateless converter for the ONNX `LayerNormalization` and `Softmax` operators.
///
/// It carries no data, so it is freely copyable and default-constructible.
#[derive(Debug, Clone, Copy, Default)]
pub struct NormalizationHandler;
12
13impl OpHandler for NormalizationHandler {
14    fn supports(&self, op_type: &str) -> bool {
15        matches!(op_type, "LayerNormalization" | "Softmax")
16    }
17
18    fn convert(
19        &self,
20        node: &NodeProto,
21        context: &ConversionContext,
22    ) -> Result<ConversionResult, OnnxError> {
23        let op_type = node.op_type.as_str();
24        let node_name = if !node.name.is_empty() {
25            node.name.as_str().to_string()
26        } else {
27            "unnamed".to_string()
28        };
29
30        match op_type {
31            "LayerNormalization" => self.convert_layer_norm(node, &node_name, context),
32            "Softmax" => self.convert_softmax(node, &node_name, context),
33            _ => Err(OnnxError::UnsupportedOp {
34                op: op_type.to_string(),
35                node: node_name,
36            }),
37        }
38    }
39}
40
41impl NormalizationHandler {
42    /// Convert ONNX LayerNormalization to WebNN layerNormalization
43    fn convert_layer_norm(
44        &self,
45        node: &NodeProto,
46        node_name: &str,
47        context: &ConversionContext,
48    ) -> Result<ConversionResult, OnnxError> {
49        let inputs = node.input.as_slice();
50        if inputs.is_empty() {
51            return Err(OnnxError::InvalidShape(
52                "LayerNormalization expects at least 1 input".to_string(),
53            ));
54        }
55
56        // Extract attributes
57        let mut epsilon = 1e-5f32;
58        let mut axis = -1i64;
59
60        for attr in node.attribute.as_slice() {
61            match attr.name.as_str() {
62                "epsilon" if attr.f != 0.0 => {
63                    epsilon = attr.f;
64                }
65                "axis" if attr.i != 0 => {
66                    axis = attr.i;
67                }
68                _ => {}
69            }
70        }
71
72        let output_name = if node.output.as_slice().is_empty() {
73            format!("{}_output", node_name)
74        } else {
75            sanitize_identifier(&node.output.as_slice()[0].to_string())
76        };
77
78        let mut options = Map::new();
79        options.insert("epsilon".to_string(), serde_json::json!(epsilon));
80
81        // WebNN layerNormalization uses positive axes.
82        if let Some(rank) = context.input_rank(inputs[0].as_str()) {
83            let normalized_axis = normalize_axis_best_effort(axis, rank);
84            options.insert("axes".to_string(), serde_json::json!([normalized_axis]));
85        } else if axis != -1 {
86            options.insert("axes".to_string(), serde_json::json!([axis]));
87        }
88
89        // LayerNormalization can have scale and bias as inputs
90        let webnn_inputs = if inputs.len() >= 3 {
91            // Input, scale, bias
92            let input0 = context.resolve_input(&inputs[0]);
93            let input1 = context.resolve_input(&inputs[1]);
94            let input2 = context.resolve_input(&inputs[2]);
95            vec![input0, input1, input2]
96        } else if inputs.len() == 2 {
97            // Input, scale
98            let input0 = context.resolve_input(&inputs[0]);
99            let input1 = context.resolve_input(&inputs[1]);
100            vec![input0, input1]
101        } else {
102            // Just input
103            let input0 = context.resolve_input(&inputs[0]);
104            vec![input0]
105        };
106
107        let mut result = ConversionResult::new(vec![Node {
108            id: output_name.clone(),
109            op: "layerNormalization".to_string(),
110            inputs: webnn_inputs,
111            options,
112            outputs: None,
113        }]);
114
115        if let Some(output) = node.output.as_slice().first() {
116            result
117                .output_mappings
118                .insert(output.to_string(), output_name.clone());
119        }
120
121        Ok(result)
122    }
123
124    /// Convert ONNX Softmax to WebNN softmax
125    fn convert_softmax(
126        &self,
127        node: &NodeProto,
128        node_name: &str,
129        context: &ConversionContext,
130    ) -> Result<ConversionResult, OnnxError> {
131        let inputs = node.input.as_slice();
132        if inputs.len() != 1 {
133            return Err(OnnxError::InvalidShape(format!(
134                "Softmax expects 1 input, got {}",
135                inputs.len()
136            )));
137        }
138
139        // Extract axis attribute
140        let mut axis = -1i64;
141        for attr in node.attribute.as_slice() {
142            if attr.name.as_str() == "axis" && attr.i != 0 {
143                axis = attr.i;
144            }
145        }
146
147        let output_name = if node.output.as_slice().is_empty() {
148            format!("{}_output", node_name)
149        } else {
150            sanitize_identifier(&node.output.as_slice()[0].to_string())
151        };
152
153        let input0 = context.resolve_input(&inputs[0]);
154
155        let axis = if let Some(rank) = context.input_rank(inputs[0].as_str()) {
156            normalize_axis_best_effort(axis, rank)
157        } else {
158            axis
159        };
160
161        let mut options = Map::new();
162        options.insert("axis".to_string(), serde_json::json!(axis));
163
164        let mut result = ConversionResult::new(vec![Node {
165            id: output_name.clone(),
166            op: "softmax".to_string(),
167            inputs: vec![input0],
168            options,
169            outputs: None,
170        }]);
171
172        if let Some(output) = node.output.as_slice().first() {
173            result
174                .output_mappings
175                .insert(output.to_string(), output_name.clone());
176        }
177
178        Ok(result)
179    }
180}
181
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protos::onnx::{AttributeProto, NodeProto};
    use std::collections::HashMap;

    /// Builds a node named `test_<lowercased op>` with the given inputs and outputs.
    fn create_test_node(op_type: &str, inputs: Vec<&str>, outputs: Vec<&str>) -> NodeProto {
        let name = format!("test_{}", op_type.to_lowercase());
        NodeProto {
            op_type: op_type.to_string(),
            name,
            input: inputs.iter().map(|s| s.to_string()).collect(),
            output: outputs.iter().map(|s| s.to_string()).collect(),
            ..Default::default()
        }
    }

    /// Appends an integer attribute `name = value` to `node`.
    fn add_int_attribute(node: &mut NodeProto, name: &str, value: i64) {
        node.attribute.push(AttributeProto {
            name: name.to_string(),
            i: value,
            ..Default::default()
        });
    }

    /// Converts `node` with a context in which only `x` has a known shape, `[1, 128, 384]`.
    fn convert_with_rank3_x(node: &NodeProto) -> ConversionResult {
        let initializers = HashMap::new();
        let const_values = HashMap::new();
        let value_ids = HashMap::new();
        let value_types = HashMap::new();
        let mut value_shapes = HashMap::new();
        value_shapes.insert("x".to_string(), vec![1, 128, 384]);

        let context = ConversionContext {
            initializers: &initializers,
            value_shapes: &value_shapes,
            value_shape_dims: crate::onnx::ops::empty_value_shape_dims(),
            const_values: &const_values,
            value_ids: &value_ids,
            value_types: &value_types,
        };

        NormalizationHandler.convert(node, &context).unwrap()
    }

    #[test]
    fn test_normalization_handler_supports() {
        let handler = NormalizationHandler;
        for supported in ["LayerNormalization", "Softmax"] {
            assert!(handler.supports(supported));
        }
        assert!(!handler.supports("Add"));
    }

    #[test]
    fn test_convert_softmax() {
        let mut node = create_test_node("Softmax", vec!["x"], vec!["y"]);
        add_int_attribute(&mut node, "axis", -1);

        let result = convert_with_rank3_x(&node);
        assert_eq!(result.nodes.len(), 1);

        let softmax = &result.nodes[0];
        assert_eq!(softmax.op, "softmax");
        assert_eq!(softmax.inputs, vec!["x"]);
        assert_eq!(softmax.id, "y");
        assert!(softmax.options.contains_key("axis"));
        assert_eq!(softmax.options.get("axis"), Some(&serde_json::json!(2)));
    }

    #[test]
    fn test_convert_layer_norm() {
        let mut node =
            create_test_node("LayerNormalization", vec!["x", "scale", "bias"], vec!["y"]);
        add_int_attribute(&mut node, "axis", -1);

        let result = convert_with_rank3_x(&node);
        assert_eq!(result.nodes.len(), 1);

        let layer_norm = &result.nodes[0];
        assert_eq!(layer_norm.op, "layerNormalization");
        assert_eq!(layer_norm.inputs.len(), 3);
        assert!(layer_norm.options.contains_key("epsilon"));
        assert_eq!(
            layer_norm.options.get("axes"),
            Some(&serde_json::json!([2]))
        );
    }
}