use std::fmt::Write;

use thiserror::Error;

use crate::ast::{ConstInit, GraphJson, Node};
3
/// Errors returned by [`serialize_graph_to_wg_text`] when a graph cannot be
/// rendered as `.wg` text.
#[derive(Debug, Error)]
pub enum SerializeError {
    /// The graph's `format` tag was not `"webnn-graph-json"`; carries the rejected tag.
    #[error("invalid format: {0}")]
    InvalidFormat(String),
    /// The graph's `version` was not `1` (the only version emitted); carries the rejected version.
    #[error("unsupported version: {0}")]
    UnsupportedVersion(u32),
}
11
12pub fn serialize_graph_to_wg_text(graph: &GraphJson) -> Result<String, SerializeError> {
13 let mut output = String::new();
14
15 if graph.format != "webnn-graph-json" {
17 return Err(SerializeError::InvalidFormat(graph.format.clone()));
18 }
19 if graph.version != 1 {
20 return Err(SerializeError::UnsupportedVersion(graph.version));
21 }
22
23 let name = graph.name.as_deref().unwrap_or("graph");
25 output.push_str(&format!("webnn_graph \"{}\" v1 {{\n", escape_string(name)));
26
27 if !graph.inputs.is_empty() {
29 output.push_str(" inputs {\n");
30 for (name, desc) in &graph.inputs {
31 let dtype = desc.data_type.to_wg_text();
32 let shape = serialize_shape(&desc.shape);
33 output.push_str(&format!(" {}: {}{};\n", name, dtype, shape));
34 }
35 output.push_str(" }\n\n");
36 }
37
38 if !graph.consts.is_empty() {
40 output.push_str(" consts {\n");
41 for (name, const_decl) in &graph.consts {
42 let dtype = const_decl.data_type.to_wg_text();
43 let shape = serialize_shape(&const_decl.shape);
44 let annotation = serialize_const_init(&const_decl.init);
45 output.push_str(&format!(" {}: {}{}{}", name, dtype, shape, annotation));
46 output.push_str(";\n");
47 }
48 output.push_str(" }\n\n");
49 }
50
51 if !graph.nodes.is_empty() {
53 output.push_str(" nodes {\n");
54 for node in &graph.nodes {
55 output.push_str(&format!(" {}\n", serialize_node(node)));
56 }
57 output.push_str(" }\n\n");
58 }
59
60 output.push_str(" outputs {");
62 if !graph.outputs.is_empty() {
63 let outputs: Vec<String> = graph.outputs.keys().map(|k| format!(" {};", k)).collect();
64 output.push_str(&outputs.join(""));
65 output.push(' ');
66 }
67 output.push_str("}\n");
68
69 output.push_str("}\n");
70 Ok(output)
71}
72
/// Formats a tensor shape as a bracketed, comma-separated dimension list,
/// e.g. `[1, 10]`; an empty shape yields `[]`.
fn serialize_shape(shape: &[u32]) -> String {
    let mut text = String::from("[");
    for (index, dim) in shape.iter().enumerate() {
        if index > 0 {
            text.push_str(", ");
        }
        text.push_str(&dim.to_string());
    }
    text.push(']');
    text
}
77
78fn serialize_const_init(init: &ConstInit) -> String {
79 match init {
80 ConstInit::Weights { r#ref } => {
81 format!(" @weights(\"{}\")", escape_string(r#ref))
82 }
83 ConstInit::Scalar { value } => {
84 format!(" @scalar({})", serialize_json_value(value))
85 }
86 ConstInit::InlineBytes { bytes } => {
87 format!(" @bytes({:?})", bytes)
90 }
91 }
92}
93
94fn serialize_node(node: &Node) -> String {
95 let call = serialize_call(&node.op, &node.inputs, &node.options);
96
97 if let Some(outputs) = &node.outputs {
98 let out_list = outputs.join(", ");
100 format!("[{}] = {};", out_list, call)
101 } else {
102 format!("{} = {};", node.id, call)
104 }
105}
106
107fn serialize_call(
108 op: &str,
109 inputs: &[String],
110 options: &serde_json::Map<String, serde_json::Value>,
111) -> String {
112 let mut args = Vec::new();
113
114 for input in inputs {
116 args.push(input.clone());
117 }
118
119 for (key, value) in options {
121 args.push(format!("{}={}", key, serialize_json_value(value)));
122 }
123
124 format!("{}({})", op, args.join(", "))
125}
126
127fn serialize_json_value(value: &serde_json::Value) -> String {
128 match value {
129 serde_json::Value::Null => "null".to_string(),
130 serde_json::Value::Bool(b) => b.to_string(),
131 serde_json::Value::Number(n) => n.to_string(),
132 serde_json::Value::String(s) => format!("\"{}\"", escape_string(s)),
133 serde_json::Value::Array(arr) => {
134 let items: Vec<String> = arr.iter().map(serialize_json_value).collect();
135 format!("[{}]", items.join(", "))
136 }
137 serde_json::Value::Object(_) => {
138 value.to_string()
140 }
141 }
142}
143
/// Escapes `s` for use inside a double-quoted `.wg` string literal.
///
/// Each backslash becomes `\\` and each double quote becomes `\"`. Every other
/// character is copied unchanged.
fn escape_string(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for ch in s.chars() {
        if matches!(ch, '\\' | '"') {
            escaped.push('\\');
        }
        escaped.push(ch);
    }
    escaped
}
147
// Unit tests for the `.wg` text serializer. They check the rendered text of each
// section, annotation and value kind, plus one parse -> serialize -> parse round trip.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::{new_graph_json, ConstDecl, ConstInit, DataType, Node, OperandDesc};
    use crate::parser::parse_wg_text;

    // One input, one single-output node and one output: checks header, sections and statements.
    #[test]
    fn test_serialize_simple_graph() {
        let mut g = new_graph_json();
        g.name = Some("test".to_string());
        g.inputs.insert(
            "x".to_string(),
            OperandDesc {
                data_type: DataType::Float32,
                shape: vec![1, 10],
            },
        );
        g.nodes.push(Node {
            id: "result".to_string(),
            op: "relu".to_string(),
            inputs: vec!["x".to_string()],
            options: serde_json::Map::new(),
            outputs: None,
        });
        g.outputs.insert("result".to_string(), "result".to_string());

        let text = serialize_graph_to_wg_text(&g).unwrap();
        assert!(text.contains("webnn_graph \"test\" v1"));
        assert!(text.contains("inputs {"));
        assert!(text.contains("x: f32[1, 10];"));
        assert!(text.contains("nodes {"));
        assert!(text.contains("result = relu(x);"));
        assert!(text.contains("outputs { result; }"));
    }

    // A constant backed by a weights reference renders as `@weights("ref")`.
    #[test]
    fn test_serialize_weights_annotation() {
        let mut g = new_graph_json();
        g.name = Some("test".to_string());
        g.consts.insert(
            "W".to_string(),
            ConstDecl {
                data_type: DataType::Float32,
                shape: vec![10, 5],
                init: ConstInit::Weights {
                    r#ref: "W".to_string(),
                },
            },
        );
        g.outputs.insert("W".to_string(), "W".to_string());

        let text = serialize_graph_to_wg_text(&g).unwrap();
        assert!(text.contains("W: f32[10, 5] @weights(\"W\");"));
    }

    // A scalar-initialized constant renders its JSON value inside `@scalar(...)`.
    #[test]
    fn test_serialize_scalar_annotation() {
        let mut g = new_graph_json();
        g.name = Some("test".to_string());
        g.consts.insert(
            "scale".to_string(),
            ConstDecl {
                data_type: DataType::Float32,
                shape: vec![1],
                init: ConstInit::Scalar {
                    value: serde_json::json!(3.5),
                },
            },
        );
        g.outputs.insert("scale".to_string(), "scale".to_string());

        let text = serialize_graph_to_wg_text(&g).unwrap();
        assert!(text.contains("scale: f32[1] @scalar(3.5);"));
    }

    // When `outputs` is set, the explicit output list replaces the node id as the target.
    #[test]
    fn test_serialize_multi_output_node() {
        let mut g = new_graph_json();
        g.name = Some("test".to_string());
        g.inputs.insert(
            "x".to_string(),
            OperandDesc {
                data_type: DataType::Float32,
                shape: vec![10],
            },
        );
        g.nodes.push(Node {
            id: "a".to_string(),
            op: "split".to_string(),
            inputs: vec!["x".to_string()],
            options: serde_json::Map::new(),
            outputs: Some(vec!["a".to_string(), "b".to_string()]),
        });
        g.outputs.insert("a".to_string(), "a".to_string());

        let text = serialize_graph_to_wg_text(&g).unwrap();
        assert!(text.contains("[a, b] = split(x);"));
    }

    // Node options follow the positional inputs as `key=value` arguments.
    #[test]
    fn test_serialize_node_options() {
        let mut g = new_graph_json();
        g.name = Some("test".to_string());
        g.inputs.insert(
            "x".to_string(),
            OperandDesc {
                data_type: DataType::Float32,
                shape: vec![1, 10],
            },
        );

        let mut options = serde_json::Map::new();
        options.insert("axis".to_string(), serde_json::json!(1));
        options.insert("keepdims".to_string(), serde_json::json!(true));

        g.nodes.push(Node {
            id: "result".to_string(),
            op: "softmax".to_string(),
            inputs: vec!["x".to_string()],
            options,
            outputs: None,
        });
        g.outputs.insert("result".to_string(), "result".to_string());

        let text = serialize_graph_to_wg_text(&g).unwrap();
        // Option order is not asserted; only their presence is checked.
        assert!(text.contains("softmax(x,"));
        assert!(text.contains("axis=1"));
        assert!(text.contains("keepdims=true"));
    }

    // Each `DataType` variant listed here maps to its short `.wg` spelling.
    #[test]
    fn test_serialize_various_dtypes() {
        let mut g = new_graph_json();
        g.name = Some("test".to_string());

        let dtypes = vec![
            ("f32_input", DataType::Float32),
            ("f16_input", DataType::Float16),
            ("i32_input", DataType::Int32),
            ("u32_input", DataType::Uint32),
            ("i64_input", DataType::Int64),
            ("u64_input", DataType::Uint64),
            ("i8_input", DataType::Int8),
            ("u8_input", DataType::Uint8),
        ];

        for (name, dtype) in dtypes {
            g.inputs.insert(
                name.to_string(),
                OperandDesc {
                    data_type: dtype,
                    shape: vec![1],
                },
            );
        }
        g.outputs
            .insert("f32_input".to_string(), "f32_input".to_string());

        let text = serialize_graph_to_wg_text(&g).unwrap();
        assert!(text.contains("f32_input: f32[1];"));
        assert!(text.contains("f16_input: f16[1];"));
        assert!(text.contains("i32_input: i32[1];"));
        assert!(text.contains("u32_input: u32[1];"));
        assert!(text.contains("i64_input: i64[1];"));
        assert!(text.contains("u64_input: u64[1];"));
        assert!(text.contains("i8_input: i8[1];"));
        assert!(text.contains("u8_input: u8[1];"));
    }

    // Parses text, serializes it, then parses the output again. Only the graph name
    // and the per-section entry counts are compared, not the individual entries.
    #[test]
    fn test_roundtrip() {
        let input = r#"
webnn_graph "resnet_head" v1 {
 inputs {
 x: f32[1, 2048];
 }
 consts {
 W: f32[2048, 1000] @weights("W");
 b: f32[1000] @weights("b");
 }
 nodes {
 logits0 = matmul(x, W);
 logits = add(logits0, b);
 probs = softmax(logits, axis=1);
 }
 outputs { probs; }
}
"#;
        let graph = parse_wg_text(input).unwrap();

        let serialized = serialize_graph_to_wg_text(&graph).unwrap();

        let graph2 = parse_wg_text(&serialized).unwrap();

        assert_eq!(graph.name, graph2.name);
        assert_eq!(graph.inputs.len(), graph2.inputs.len());
        assert_eq!(graph.consts.len(), graph2.consts.len());
        assert_eq!(graph.nodes.len(), graph2.nodes.len());
        assert_eq!(graph.outputs.len(), graph2.outputs.len());
    }

    // A graph without a `name` is emitted with the fallback name "graph".
    #[test]
    fn test_default_graph_name() {
        let mut g = new_graph_json();
        g.outputs.insert("x".to_string(), "x".to_string());

        let text = serialize_graph_to_wg_text(&g).unwrap();
        assert!(text.contains("webnn_graph \"graph\" v1"));
    }

    // Quotes and backslashes in the graph name are backslash-escaped in the header.
    #[test]
    fn test_string_escaping() {
        let mut g = new_graph_json();
        g.name = Some("test\"with\\quotes".to_string());
        g.outputs.insert("x".to_string(), "x".to_string());

        let text = serialize_graph_to_wg_text(&g).unwrap();
        assert!(text.contains("webnn_graph \"test\\\"with\\\\quotes\" v1"));
    }

    // Covers how each JSON value kind used in options is rendered.
    #[test]
    fn test_value_types() {
        let mut g = new_graph_json();
        g.name = Some("test".to_string());
        g.inputs.insert(
            "x".to_string(),
            OperandDesc {
                data_type: DataType::Float32,
                shape: vec![1],
            },
        );

        let mut options = serde_json::Map::new();
        options.insert("int_val".to_string(), serde_json::json!(42));
        options.insert("float_val".to_string(), serde_json::json!(3.5));
        options.insert("bool_val".to_string(), serde_json::json!(true));
        options.insert("null_val".to_string(), serde_json::json!(null));
        options.insert("array_val".to_string(), serde_json::json!([1, 2, 3]));

        g.nodes.push(Node {
            id: "result".to_string(),
            op: "test_op".to_string(),
            inputs: vec!["x".to_string()],
            options,
            outputs: None,
        });
        g.outputs.insert("result".to_string(), "result".to_string());

        let text = serialize_graph_to_wg_text(&g).unwrap();
        assert!(text.contains("int_val=42"));
        assert!(text.contains("float_val=3.5"));
        assert!(text.contains("bool_val=true"));
        assert!(text.contains("null_val=null"));
        assert!(text.contains("array_val=[1, 2, 3]"));
    }
}