Skip to main content

webnn_graph/onnx/ops/
conversion.rs

1// Type conversion and constant operators: Cast, Constant
2
3use crate::ast::Node;
4use crate::onnx::convert::{sanitize_identifier, OnnxError};
5use crate::onnx::ops::{ConversionContext, ConversionResult, OpHandler};
6use crate::protos::onnx::NodeProto;
7use serde_json::Map;
8
/// Stateless handler for the ONNX type-conversion and constant operators
/// (`Cast`, `Constant`).
#[derive(Debug, Default, Clone, Copy)]
pub struct ConversionHandler;
10
11fn dtype_to_webnn_string(dt: &crate::ast::DataType) -> &'static str {
12    match dt {
13        crate::ast::DataType::Float32 => "float32",
14        crate::ast::DataType::Float16 => "float16",
15        crate::ast::DataType::Int4 => "int4",
16        crate::ast::DataType::Uint4 => "uint4",
17        crate::ast::DataType::Int32 => "int32",
18        crate::ast::DataType::Uint32 => "uint32",
19        crate::ast::DataType::Int64 => "int64",
20        crate::ast::DataType::Uint64 => "uint64",
21        crate::ast::DataType::Int8 => "int8",
22        crate::ast::DataType::Uint8 => "uint8",
23    }
24}
25
26impl OpHandler for ConversionHandler {
27    fn supports(&self, op_type: &str) -> bool {
28        matches!(op_type, "Cast" | "Constant")
29    }
30
31    fn convert(
32        &self,
33        node: &NodeProto,
34        context: &ConversionContext,
35    ) -> Result<ConversionResult, OnnxError> {
36        let op_type = node.op_type.as_str();
37        let node_name = if !node.name.is_empty() {
38            node.name.as_str().to_string()
39        } else {
40            "unnamed".to_string()
41        };
42
43        match op_type {
44            "Cast" => self.convert_cast(node, &node_name, context),
45            "Constant" => self.convert_constant(node, &node_name),
46            _ => Err(OnnxError::UnsupportedOp {
47                op: op_type.to_string(),
48                node: node_name,
49            }),
50        }
51    }
52}
53
54impl ConversionHandler {
55    /// Convert ONNX Cast to WebNN cast
56    /// ONNX Cast converts tensor data type
57    fn convert_cast(
58        &self,
59        node: &NodeProto,
60        node_name: &str,
61        context: &ConversionContext,
62    ) -> Result<ConversionResult, OnnxError> {
63        let inputs = node.input.as_slice();
64        if inputs.len() != 1 {
65            return Err(OnnxError::InvalidShape(format!(
66                "Cast expects 1 input, got {}",
67                inputs.len()
68            )));
69        }
70
71        // Extract 'to' attribute (target data type)
72        let mut to_type: Option<i64> = None;
73        for attr in node.attribute.as_slice() {
74            if attr.name.as_str() == "to" && attr.i != 0 {
75                to_type = Some(attr.i);
76            }
77        }
78
79        if to_type.is_none() {
80            return Err(OnnxError::MissingAttribute {
81                attr: "to".to_string(),
82                op: "Cast".to_string(),
83            });
84        }
85
86        let output_name = if node.output.as_slice().is_empty() {
87            format!("{}_output", node_name)
88        } else {
89            sanitize_identifier(&node.output.as_slice()[0].to_string())
90        };
91
92        let input0 = context.resolve_input(&inputs[0]);
93
94        // Map ONNX type to WebNN DataType
95        let target_type = crate::onnx::convert::map_onnx_data_type(to_type.unwrap() as i32)?;
96
97        let mut options = Map::new();
98        options.insert(
99            "to".to_string(),
100            serde_json::json!(dtype_to_webnn_string(&target_type)),
101        );
102
103        let mut result = ConversionResult::new(vec![Node {
104            id: output_name.clone(),
105            op: "cast".to_string(),
106            inputs: vec![input0],
107            options,
108            outputs: None,
109        }]);
110
111        if let Some(output) = node.output.as_slice().first() {
112            result
113                .output_mappings
114                .insert(output.to_string(), output_name.clone());
115        }
116
117        Ok(result)
118    }
119
120    /// Convert ONNX Constant to WebNN constant
121    /// ONNX Constant creates an inline constant tensor
122    fn convert_constant(
123        &self,
124        node: &NodeProto,
125        node_name: &str,
126    ) -> Result<ConversionResult, OnnxError> {
127        let output_name = if node.output.as_slice().is_empty() {
128            format!("{}_output", node_name)
129        } else {
130            sanitize_identifier(&node.output.as_slice()[0].to_string())
131        };
132
133        // Extract 'value' attribute (TensorProto)
134        let tensor = node
135            .attribute
136            .as_slice()
137            .iter()
138            .find_map(|attr| {
139                if attr.name.as_str() == "value" {
140                    attr.t.as_ref()
141                } else {
142                    None
143                }
144            })
145            .ok_or_else(|| OnnxError::MissingAttribute {
146                attr: "value".to_string(),
147                op: "Constant".to_string(),
148            })?;
149        let onnx_type = tensor.data_type;
150        let data_type = crate::onnx::convert::map_onnx_data_type(onnx_type)?;
151
152        let shape: Vec<i64> = tensor.dims.as_slice().to_vec();
153        let raw_data = tensor.raw_data.as_slice().to_vec();
154
155        let mut options = Map::new();
156        options.insert(
157            "dataType".to_string(),
158            serde_json::json!(dtype_to_webnn_string(&data_type)),
159        );
160        options.insert("shape".to_string(), serde_json::json!(shape));
161
162        // Store raw bytes as base64 for now (WebNN implementation can decode)
163        let b64_data =
164            base64::Engine::encode(&base64::engine::general_purpose::STANDARD, &raw_data);
165        options.insert("data".to_string(), serde_json::json!(b64_data));
166
167        let mut result = ConversionResult::new(vec![Node {
168            id: output_name.clone(),
169            op: "constant".to_string(),
170            inputs: vec![],
171            options,
172            outputs: None,
173        }]);
174
175        if let Some(output) = node.output.as_slice().first() {
176            result
177                .output_mappings
178                .insert(output.to_string(), output_name.clone());
179        }
180
181        Ok(result)
182    }
183}
184
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protos::onnx::{AttributeProto, NodeProto};
    use std::collections::HashMap;

    /// Builds a named node with the given op type, inputs, and outputs.
    fn create_test_node(op_type: &str, inputs: Vec<&str>, outputs: Vec<&str>) -> NodeProto {
        NodeProto {
            op_type: op_type.to_string(),
            name: format!("test_{}", op_type.to_lowercase()),
            input: inputs.iter().map(|s| s.to_string()).collect(),
            output: outputs.iter().map(|s| s.to_string()).collect(),
            ..Default::default()
        }
    }

    /// Appends an integer attribute to `node`.
    fn add_int_attribute(node: &mut NodeProto, name: &str, value: i64) {
        let attr = AttributeProto {
            name: name.to_string(),
            i: value,
            ..Default::default()
        };
        node.attribute.push(attr);
    }

    /// Runs `ConversionHandler` on `node` with a context whose lookup tables are all empty.
    fn convert_with_empty_context(node: &NodeProto) -> Result<ConversionResult, OnnxError> {
        let initializers = HashMap::new();
        let value_shapes = HashMap::new();
        let const_values = HashMap::new();
        let value_ids = HashMap::new();
        let value_types = HashMap::new();
        let context = ConversionContext {
            initializers: &initializers,
            value_shapes: &value_shapes,
            value_shape_dims: crate::onnx::ops::empty_value_shape_dims(),
            const_values: &const_values,
            value_ids: &value_ids,
            value_types: &value_types,
        };
        ConversionHandler.convert(node, &context)
    }

    #[test]
    fn test_conversion_handler_supports() {
        let handler = ConversionHandler;
        assert!(handler.supports("Cast"));
        assert!(handler.supports("Constant"));
        assert!(!handler.supports("Add"));
    }

    #[test]
    fn test_convert_cast() {
        let mut node = create_test_node("Cast", vec!["x"], vec!["y"]);
        add_int_attribute(&mut node, "to", 7); // INT64

        let result = convert_with_empty_context(&node).unwrap();
        assert_eq!(result.nodes.len(), 1);
        assert_eq!(result.nodes[0].op, "cast");
        assert_eq!(result.nodes[0].inputs, vec!["x"]);
        assert!(result.nodes[0].options.contains_key("to"));
        assert_eq!(
            result.nodes[0].options.get("to"),
            Some(&serde_json::json!("int64"))
        );
    }

    #[test]
    fn test_convert_cast_requires_to_attribute() {
        let node = create_test_node("Cast", vec!["x"], vec!["y"]);
        assert!(matches!(
            convert_with_empty_context(&node),
            Err(OnnxError::MissingAttribute { .. })
        ));
    }

    #[test]
    fn test_convert_cast_rejects_wrong_input_count() {
        let mut node = create_test_node("Cast", vec!["x", "z"], vec!["y"]);
        add_int_attribute(&mut node, "to", 7);
        assert!(matches!(
            convert_with_empty_context(&node),
            Err(OnnxError::InvalidShape(_))
        ));
    }

    #[test]
    fn test_convert_rejects_unsupported_op() {
        let node = create_test_node("Add", vec!["a", "b"], vec!["c"]);
        assert!(matches!(
            convert_with_empty_context(&node),
            Err(OnnxError::UnsupportedOp { .. })
        ));
    }

    #[test]
    fn test_convert_constant_uses_lowercase_dtype_and_base64_data() {
        let mut node = create_test_node("Constant", vec![], vec!["c0"]);
        let tensor = crate::protos::onnx::TensorProto {
            data_type: crate::protos::onnx::TensorProto_DataType::Float as i32,
            dims: vec![1],
            raw_data: vec![0, 0, 128, 63], // 1.0f32
            ..Default::default()
        };
        node.attribute.push(AttributeProto {
            name: "value".to_string(),
            t: Some(tensor),
            ..Default::default()
        });

        let result = convert_with_empty_context(&node).unwrap();

        assert_eq!(result.nodes.len(), 1);
        assert_eq!(result.nodes[0].op, "constant");
        assert_eq!(
            result.nodes[0].options.get("dataType"),
            Some(&serde_json::json!("float32"))
        );
        assert_eq!(
            result.nodes[0].options.get("shape"),
            Some(&serde_json::json!([1]))
        );
        // base64 of [0x00, 0x00, 0x80, 0x3F]
        assert_eq!(
            result.nodes[0].options.get("data"),
            Some(&serde_json::json!("AACAPw=="))
        );
    }
}