1use crate::ast::Node;
4use crate::onnx::convert::{sanitize_identifier, OnnxError};
5use crate::onnx::ops::{ConversionContext, ConversionResult, OpHandler};
6use crate::protos::onnx::NodeProto;
7use serde_json::Map;
8
/// Stateless [`OpHandler`] that converts ONNX `Cast` and `Constant` nodes into
/// WebNN graph nodes.
#[derive(Debug, Clone, Copy, Default)]
pub struct ConversionHandler;
10
11fn dtype_to_webnn_string(dt: &crate::ast::DataType) -> &'static str {
12 match dt {
13 crate::ast::DataType::Float32 => "float32",
14 crate::ast::DataType::Float16 => "float16",
15 crate::ast::DataType::Int4 => "int4",
16 crate::ast::DataType::Uint4 => "uint4",
17 crate::ast::DataType::Int32 => "int32",
18 crate::ast::DataType::Uint32 => "uint32",
19 crate::ast::DataType::Int64 => "int64",
20 crate::ast::DataType::Uint64 => "uint64",
21 crate::ast::DataType::Int8 => "int8",
22 crate::ast::DataType::Uint8 => "uint8",
23 }
24}
25
26impl OpHandler for ConversionHandler {
27 fn supports(&self, op_type: &str) -> bool {
28 matches!(op_type, "Cast" | "Constant")
29 }
30
31 fn convert(
32 &self,
33 node: &NodeProto,
34 context: &ConversionContext,
35 ) -> Result<ConversionResult, OnnxError> {
36 let op_type = node.op_type.as_str();
37 let node_name = if !node.name.is_empty() {
38 node.name.as_str().to_string()
39 } else {
40 "unnamed".to_string()
41 };
42
43 match op_type {
44 "Cast" => self.convert_cast(node, &node_name, context),
45 "Constant" => self.convert_constant(node, &node_name),
46 _ => Err(OnnxError::UnsupportedOp {
47 op: op_type.to_string(),
48 node: node_name,
49 }),
50 }
51 }
52}
53
54impl ConversionHandler {
55 fn convert_cast(
58 &self,
59 node: &NodeProto,
60 node_name: &str,
61 context: &ConversionContext,
62 ) -> Result<ConversionResult, OnnxError> {
63 let inputs = node.input.as_slice();
64 if inputs.len() != 1 {
65 return Err(OnnxError::InvalidShape(format!(
66 "Cast expects 1 input, got {}",
67 inputs.len()
68 )));
69 }
70
71 let mut to_type: Option<i64> = None;
73 for attr in node.attribute.as_slice() {
74 if attr.name.as_str() == "to" && attr.i != 0 {
75 to_type = Some(attr.i);
76 }
77 }
78
79 if to_type.is_none() {
80 return Err(OnnxError::MissingAttribute {
81 attr: "to".to_string(),
82 op: "Cast".to_string(),
83 });
84 }
85
86 let output_name = if node.output.as_slice().is_empty() {
87 format!("{}_output", node_name)
88 } else {
89 sanitize_identifier(&node.output.as_slice()[0].to_string())
90 };
91
92 let input0 = context.resolve_input(&inputs[0]);
93
94 let target_type = crate::onnx::convert::map_onnx_data_type(to_type.unwrap() as i32)?;
96
97 let mut options = Map::new();
98 options.insert(
99 "to".to_string(),
100 serde_json::json!(dtype_to_webnn_string(&target_type)),
101 );
102
103 let mut result = ConversionResult::new(vec![Node {
104 id: output_name.clone(),
105 op: "cast".to_string(),
106 inputs: vec![input0],
107 options,
108 outputs: None,
109 }]);
110
111 if let Some(output) = node.output.as_slice().first() {
112 result
113 .output_mappings
114 .insert(output.to_string(), output_name.clone());
115 }
116
117 Ok(result)
118 }
119
120 fn convert_constant(
123 &self,
124 node: &NodeProto,
125 node_name: &str,
126 ) -> Result<ConversionResult, OnnxError> {
127 let output_name = if node.output.as_slice().is_empty() {
128 format!("{}_output", node_name)
129 } else {
130 sanitize_identifier(&node.output.as_slice()[0].to_string())
131 };
132
133 let tensor = node
135 .attribute
136 .as_slice()
137 .iter()
138 .find_map(|attr| {
139 if attr.name.as_str() == "value" {
140 attr.t.as_ref()
141 } else {
142 None
143 }
144 })
145 .ok_or_else(|| OnnxError::MissingAttribute {
146 attr: "value".to_string(),
147 op: "Constant".to_string(),
148 })?;
149 let onnx_type = tensor.data_type;
150 let data_type = crate::onnx::convert::map_onnx_data_type(onnx_type)?;
151
152 let shape: Vec<i64> = tensor.dims.as_slice().to_vec();
153 let raw_data = tensor.raw_data.as_slice().to_vec();
154
155 let mut options = Map::new();
156 options.insert(
157 "dataType".to_string(),
158 serde_json::json!(dtype_to_webnn_string(&data_type)),
159 );
160 options.insert("shape".to_string(), serde_json::json!(shape));
161
162 let b64_data =
164 base64::Engine::encode(&base64::engine::general_purpose::STANDARD, &raw_data);
165 options.insert("data".to_string(), serde_json::json!(b64_data));
166
167 let mut result = ConversionResult::new(vec![Node {
168 id: output_name.clone(),
169 op: "constant".to_string(),
170 inputs: vec![],
171 options,
172 outputs: None,
173 }]);
174
175 if let Some(output) = node.output.as_slice().first() {
176 result
177 .output_mappings
178 .insert(output.to_string(), output_name.clone());
179 }
180
181 Ok(result)
182 }
183}
184
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protos::onnx::{AttributeProto, NodeProto};
    use std::collections::HashMap;

    /// Builds a named `NodeProto` with the given op type, inputs and outputs.
    fn create_test_node(op_type: &str, inputs: Vec<&str>, outputs: Vec<&str>) -> NodeProto {
        NodeProto {
            op_type: op_type.to_string(),
            name: format!("test_{}", op_type.to_lowercase()),
            input: inputs.iter().map(|s| s.to_string()).collect(),
            output: outputs.iter().map(|s| s.to_string()).collect(),
            ..Default::default()
        }
    }

    /// Appends an integer attribute `name = value` to `node`.
    fn add_int_attribute(node: &mut NodeProto, name: &str, value: i64) {
        let attr = AttributeProto {
            name: name.to_string(),
            i: value,
            ..Default::default()
        };
        node.attribute.push(attr);
    }

    /// Runs `f` with a `ConversionContext` whose lookup tables are all empty.
    fn with_empty_context<R>(f: impl FnOnce(&ConversionContext) -> R) -> R {
        let initializers = HashMap::new();
        let value_shapes = HashMap::new();
        let const_values = HashMap::new();
        let value_ids = HashMap::new();
        let value_types = HashMap::new();
        let context = ConversionContext {
            initializers: &initializers,
            value_shapes: &value_shapes,
            value_shape_dims: crate::onnx::ops::empty_value_shape_dims(),
            const_values: &const_values,
            value_ids: &value_ids,
            value_types: &value_types,
        };
        f(&context)
    }

    #[test]
    fn test_conversion_handler_supports() {
        let handler = ConversionHandler;
        assert!(handler.supports("Cast"));
        assert!(handler.supports("Constant"));
        assert!(!handler.supports("Add"));
    }

    #[test]
    fn test_convert_cast() {
        let handler = ConversionHandler;
        let mut node = create_test_node("Cast", vec!["x"], vec!["y"]);
        add_int_attribute(&mut node, "to", 7);

        let result = with_empty_context(|ctx| handler.convert(&node, ctx)).unwrap();
        assert_eq!(result.nodes.len(), 1);
        assert_eq!(result.nodes[0].op, "cast");
        assert_eq!(result.nodes[0].inputs, vec!["x"]);
        assert!(result.nodes[0].options.contains_key("to"));
        assert_eq!(
            result.nodes[0].options.get("to"),
            Some(&serde_json::json!("int64"))
        );
    }

    #[test]
    fn test_convert_cast_without_to_attribute_fails() {
        let handler = ConversionHandler;
        let node = create_test_node("Cast", vec!["x"], vec!["y"]);

        let result = with_empty_context(|ctx| handler.convert(&node, ctx));
        assert!(matches!(result, Err(OnnxError::MissingAttribute { .. })));
    }

    #[test]
    fn test_convert_cast_rejects_wrong_input_count() {
        let handler = ConversionHandler;
        let mut node = create_test_node("Cast", vec!["x", "z"], vec!["y"]);
        add_int_attribute(&mut node, "to", 1);

        let result = with_empty_context(|ctx| handler.convert(&node, ctx));
        assert!(matches!(result, Err(OnnxError::InvalidShape(_))));
    }

    #[test]
    fn test_convert_constant_uses_lowercase_dtype_and_base64_data() {
        let handler = ConversionHandler;
        let mut node = create_test_node("Constant", vec![], vec!["c0"]);
        // 1.0f32 in little-endian byte order.
        let tensor = crate::protos::onnx::TensorProto {
            data_type: crate::protos::onnx::TensorProto_DataType::Float as i32,
            dims: vec![1],
            raw_data: vec![0, 0, 128, 63],
            ..Default::default()
        };
        node.attribute.push(AttributeProto {
            name: "value".to_string(),
            t: Some(tensor),
            ..Default::default()
        });

        let result = with_empty_context(|ctx| handler.convert(&node, ctx)).unwrap();

        assert_eq!(result.nodes.len(), 1);
        assert_eq!(result.nodes[0].op, "constant");
        assert_eq!(
            result.nodes[0].options.get("dataType"),
            Some(&serde_json::json!("float32"))
        );
        assert_eq!(
            result.nodes[0].options.get("shape"),
            Some(&serde_json::json!([1]))
        );
        assert_eq!(
            result.nodes[0].options.get("data"),
            Some(&serde_json::json!("AACAPw=="))
        );
    }

    #[test]
    fn test_convert_constant_without_value_fails() {
        let handler = ConversionHandler;
        let node = create_test_node("Constant", vec![], vec!["c0"]);

        let result = with_empty_context(|ctx| handler.convert(&node, ctx));
        assert!(matches!(result, Err(OnnxError::MissingAttribute { .. })));
    }

    #[test]
    fn test_convert_unsupported_op_is_rejected() {
        let handler = ConversionHandler;
        let node = create_test_node("Add", vec!["a", "b"], vec!["c"]);

        let result = with_empty_context(|ctx| handler.convert(&node, ctx));
        assert!(matches!(
            result,
            Err(OnnxError::UnsupportedOp { ref op, .. }) if op == "Add"
        ));
    }
}