Skip to main content

webnn_graph/
serialize.rs

1use crate::ast::{ConstInit, Dimension, GraphJson, Node};
2use thiserror::Error;
3
4#[derive(Debug, Error)]
5pub enum SerializeError {
6    #[error("invalid format: {0}")]
7    InvalidFormat(String),
8    #[error("unsupported version: {0}")]
9    UnsupportedVersion(u32),
10}
11
/// Knobs for WG text serialization.
///
/// `SerializeOptions::default()` turns every option off.
#[derive(Debug, Default, Clone, Copy)]
pub struct SerializeOptions {
    /// Emit the `@quantized` annotation on the graph header, signalling that
    /// tensors/constants are already quantized and should be preserved as-is.
    /// Downstream tools can use this hint to avoid dequantizing or
    /// re-quantizing. The serializer also emits the annotation when the graph
    /// itself is flagged as quantized, so this option can add the marker but
    /// never suppress it.
    pub quantized: bool,
}
20
21pub fn serialize_graph_to_wg_text(
22    graph: &GraphJson,
23    opts: SerializeOptions,
24) -> Result<String, SerializeError> {
25    let mut output = String::new();
26
27    // Validate format
28    if graph.format != "webnn-graph-json" {
29        return Err(SerializeError::InvalidFormat(graph.format.clone()));
30    }
31    if graph.version != 1 && graph.version != 2 {
32        return Err(SerializeError::UnsupportedVersion(graph.version));
33    }
34
35    // Header
36    let name = graph.name.as_deref().unwrap_or("graph");
37    let quantized_flag = if opts.quantized || graph.quantized {
38        " @quantized"
39    } else {
40        ""
41    };
42    output.push_str(&format!(
43        "webnn_graph \"{}\" v{}{} {{\n",
44        escape_string(name),
45        graph.version,
46        quantized_flag
47    ));
48
49    // Inputs block
50    if !graph.inputs.is_empty() {
51        output.push_str("  inputs {\n");
52        for (name, desc) in &graph.inputs {
53            let dtype = desc.data_type.to_wg_text();
54            let shape = serialize_shape(&desc.shape)?;
55            output.push_str(&format!("    {}: {}{};\n", name, dtype, shape));
56        }
57        output.push_str("  }\n\n");
58    }
59
60    // Consts block
61    if !graph.consts.is_empty() {
62        output.push_str("  consts {\n");
63        for (name, const_decl) in &graph.consts {
64            let dtype = const_decl.data_type.to_wg_text();
65            let shape = serialize_shape_u32(&const_decl.shape);
66            let annotation = serialize_const_init(&const_decl.init);
67            output.push_str(&format!("    {}: {}{}{}", name, dtype, shape, annotation));
68            output.push_str(";\n");
69        }
70        output.push_str("  }\n\n");
71    }
72
73    // Nodes block
74    if !graph.nodes.is_empty() {
75        output.push_str("  nodes {\n");
76        for node in &graph.nodes {
77            output.push_str(&format!("    {}\n", serialize_node(node)));
78        }
79        output.push_str("  }\n\n");
80    }
81
82    // Outputs block
83    output.push_str("  outputs {");
84    if !graph.outputs.is_empty() {
85        let outputs: Vec<String> = graph.outputs.keys().map(|k| format!(" {};", k)).collect();
86        output.push_str(&outputs.join(""));
87        output.push(' ');
88    }
89    output.push_str("}\n");
90
91    output.push_str("}\n");
92    Ok(output)
93}
94
95fn serialize_shape(shape: &[Dimension]) -> Result<String, SerializeError> {
96    let mut dims: Vec<String> = Vec::with_capacity(shape.len());
97    for dim in shape {
98        match dim {
99            Dimension::Static(v) => dims.push(v.to_string()),
100            Dimension::Dynamic(d) => {
101                dims.push(format!(
102                    "dyn(\"{}\", {})",
103                    escape_string(&d.name),
104                    d.max_size
105                ));
106            }
107        }
108    }
109    Ok(format!("[{}]", dims.join(", ")))
110}
111
/// Formats a constant's static shape as `[d0, d1, ...]` (`[]` when empty).
fn serialize_shape_u32(shape: &[u32]) -> String {
    let mut rendered = String::from("[");
    for (position, dim) in shape.iter().enumerate() {
        if position > 0 {
            rendered.push_str(", ");
        }
        rendered.push_str(&dim.to_string());
    }
    rendered.push(']');
    rendered
}
116
117fn serialize_const_init(init: &ConstInit) -> String {
118    match init {
119        ConstInit::Weights { r#ref } => {
120            format!(" @weights(\"{}\")", escape_string(r#ref))
121        }
122        ConstInit::Scalar { value } => {
123            format!(" @scalar({})", serialize_json_value(value))
124        }
125        ConstInit::InlineBytes { bytes } => {
126            let nums: Vec<String> = bytes.iter().map(|b| b.to_string()).collect();
127            format!(" @bytes([{}])", nums.join(", "))
128        }
129    }
130}
131
132fn serialize_node(node: &Node) -> String {
133    let call = serialize_call(&node.op, &node.inputs, &node.options);
134
135    if let Some(outputs) = &node.outputs {
136        // Multi-output node: [a, b] = op(...)
137        let out_list = outputs.join(", ");
138        format!("[{}] = {};", out_list, call)
139    } else {
140        // Single output node: id = op(...)
141        format!("{} = {};", node.id, call)
142    }
143}
144
145fn serialize_call(
146    op: &str,
147    inputs: &[String],
148    options: &serde_json::Map<String, serde_json::Value>,
149) -> String {
150    let mut args = Vec::new();
151
152    // Add positional inputs
153    for input in inputs {
154        args.push(input.clone());
155    }
156
157    // Add named options
158    for (key, value) in options {
159        args.push(format!("{}={}", key, serialize_json_value(value)));
160    }
161
162    format!("{}({})", op, args.join(", "))
163}
164
165fn serialize_json_value(value: &serde_json::Value) -> String {
166    match value {
167        serde_json::Value::Null => "null".to_string(),
168        serde_json::Value::Bool(b) => b.to_string(),
169        serde_json::Value::Number(n) => n.to_string(),
170        serde_json::Value::String(s) => format!("\"{}\"", escape_string(s)),
171        serde_json::Value::Array(arr) => {
172            let items: Vec<String> = arr.iter().map(serialize_json_value).collect();
173            format!("[{}]", items.join(", "))
174        }
175        serde_json::Value::Object(_) => {
176            // Objects are not typically used in the WG text format
177            value.to_string()
178        }
179    }
180}
181
/// Escapes `s` for use inside a double-quoted WG string literal.
///
/// Only backslash (`\` becomes `\\`) and double quote (`"` becomes `\"`) are
/// escaped. Every other character, including control characters such as
/// newlines, is copied through unchanged.
fn escape_string(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for ch in s.chars() {
        if ch == '\\' || ch == '"' {
            escaped.push('\\');
        }
        escaped.push(ch);
    }
    escaped
}
185
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::{
        new_graph_json, to_dimension_vector, ConstDecl, ConstInit, DataType, Node, OperandDesc,
    };
    use crate::parser::parse_wg_text;

    /// Collects map keys into a sorted `Vec` so comparisons do not depend on
    /// the iteration order of the map type used by `GraphJson`.
    fn sorted_keys<'a>(keys: impl Iterator<Item = &'a String>) -> Vec<String> {
        let mut collected: Vec<String> = keys.cloned().collect();
        collected.sort();
        collected
    }

    #[test]
    fn test_serialize_simple_graph() {
        let mut g = new_graph_json();
        g.name = Some("test".to_string());
        g.inputs.insert(
            "x".to_string(),
            OperandDesc {
                data_type: DataType::Float32,
                shape: to_dimension_vector(&[1, 10]),
            },
        );
        g.nodes.push(Node {
            id: "result".to_string(),
            op: "relu".to_string(),
            inputs: vec!["x".to_string()],
            options: serde_json::Map::new(),
            outputs: None,
        });
        g.outputs.insert("result".to_string(), "result".to_string());

        let text = serialize_graph_to_wg_text(&g, SerializeOptions::default()).unwrap();
        assert!(text.contains(&format!("webnn_graph \"test\" v{}", g.version)));
        assert!(text.contains("inputs {"));
        assert!(text.contains("x: f32[1, 10];"));
        assert!(text.contains("nodes {"));
        assert!(text.contains("result = relu(x);"));
        assert!(text.contains("outputs { result; }"));
    }

    #[test]
    fn test_serialize_dynamic_input_shape() {
        let mut g = new_graph_json();
        g.name = Some("dyn".to_string());
        g.inputs.insert(
            "x".to_string(),
            OperandDesc {
                data_type: DataType::Float32,
                shape: vec![
                    Dimension::Dynamic(crate::ast::DynamicDimension {
                        name: "batch_size".to_string(),
                        max_size: 8,
                    }),
                    Dimension::Static(128),
                ],
            },
        );
        g.outputs.insert("x".to_string(), "x".to_string());
        let text = serialize_graph_to_wg_text(&g, SerializeOptions::default()).unwrap();
        assert!(text.contains("x: f32[dyn(\"batch_size\", 8), 128];"));
    }

    #[test]
    fn test_serialize_weights_annotation() {
        let mut g = new_graph_json();
        g.name = Some("test".to_string());
        g.consts.insert(
            "W".to_string(),
            ConstDecl {
                data_type: DataType::Float32,
                shape: vec![10, 5],
                init: ConstInit::Weights {
                    r#ref: "W".to_string(),
                },
            },
        );
        g.outputs.insert("W".to_string(), "W".to_string());

        let text = serialize_graph_to_wg_text(&g, SerializeOptions::default()).unwrap();
        assert!(text.contains("W: f32[10, 5] @weights(\"W\");"));
    }

    #[test]
    fn test_serialize_scalar_annotation() {
        let mut g = new_graph_json();
        g.name = Some("test".to_string());
        g.consts.insert(
            "scale".to_string(),
            ConstDecl {
                data_type: DataType::Float32,
                shape: vec![1],
                init: ConstInit::Scalar {
                    value: serde_json::json!(3.5),
                },
            },
        );
        g.outputs.insert("scale".to_string(), "scale".to_string());

        let text = serialize_graph_to_wg_text(&g, SerializeOptions::default()).unwrap();
        assert!(text.contains("scale: f32[1] @scalar(3.5);"));
    }

    #[test]
    fn test_serialize_inline_bytes_annotation() {
        let mut g = new_graph_json();
        g.name = Some("test".to_string());
        g.consts.insert(
            "table".to_string(),
            ConstDecl {
                data_type: DataType::Uint8,
                shape: vec![3],
                init: ConstInit::InlineBytes {
                    bytes: vec![0, 7, 42],
                },
            },
        );
        g.outputs.insert("table".to_string(), "table".to_string());

        let text = serialize_graph_to_wg_text(&g, SerializeOptions::default()).unwrap();
        assert!(text.contains("table: u8[3] @bytes([0, 7, 42]);"));
    }

    #[test]
    fn test_serialize_multi_output_node() {
        let mut g = new_graph_json();
        g.name = Some("test".to_string());
        g.inputs.insert(
            "x".to_string(),
            OperandDesc {
                data_type: DataType::Float32,
                shape: to_dimension_vector(&[10]),
            },
        );
        g.nodes.push(Node {
            id: "a".to_string(),
            op: "split".to_string(),
            inputs: vec!["x".to_string()],
            options: serde_json::Map::new(),
            outputs: Some(vec!["a".to_string(), "b".to_string()]),
        });
        g.outputs.insert("a".to_string(), "a".to_string());

        let text = serialize_graph_to_wg_text(&g, SerializeOptions::default()).unwrap();
        assert!(text.contains("[a, b] = split(x);"));
    }

    #[test]
    fn test_serialize_node_options() {
        let mut g = new_graph_json();
        g.name = Some("test".to_string());
        g.inputs.insert(
            "x".to_string(),
            OperandDesc {
                data_type: DataType::Float32,
                shape: to_dimension_vector(&[1, 10]),
            },
        );

        let mut options = serde_json::Map::new();
        options.insert("axis".to_string(), serde_json::json!(1));
        options.insert("keepdims".to_string(), serde_json::json!(true));

        g.nodes.push(Node {
            id: "result".to_string(),
            op: "softmax".to_string(),
            inputs: vec!["x".to_string()],
            options,
            outputs: None,
        });
        g.outputs.insert("result".to_string(), "result".to_string());

        let text = serialize_graph_to_wg_text(&g, SerializeOptions::default()).unwrap();
        assert!(text.contains("softmax(x,"));
        assert!(text.contains("axis=1"));
        assert!(text.contains("keepdims=true"));
    }

    #[test]
    fn test_serialize_string_option_is_escaped() {
        let mut g = new_graph_json();
        g.name = Some("test".to_string());
        g.inputs.insert(
            "x".to_string(),
            OperandDesc {
                data_type: DataType::Float32,
                shape: to_dimension_vector(&[1]),
            },
        );

        let mut options = serde_json::Map::new();
        options.insert("label".to_string(), serde_json::json!("a\"b\\c"));

        g.nodes.push(Node {
            id: "result".to_string(),
            op: "identity".to_string(),
            inputs: vec!["x".to_string()],
            options,
            outputs: None,
        });
        g.outputs.insert("result".to_string(), "result".to_string());

        let text = serialize_graph_to_wg_text(&g, SerializeOptions::default()).unwrap();
        assert!(text.contains(r#"label="a\"b\\c""#));
    }

    #[test]
    fn test_serialize_various_dtypes() {
        let mut g = new_graph_json();
        g.name = Some("test".to_string());

        let dtypes = vec![
            ("f32_input", DataType::Float32),
            ("f16_input", DataType::Float16),
            ("i32_input", DataType::Int32),
            ("u32_input", DataType::Uint32),
            ("i64_input", DataType::Int64),
            ("u64_input", DataType::Uint64),
            ("i8_input", DataType::Int8),
            ("u8_input", DataType::Uint8),
        ];

        for (name, dtype) in dtypes {
            g.inputs.insert(
                name.to_string(),
                OperandDesc {
                    data_type: dtype,
                    shape: to_dimension_vector(&[1]),
                },
            );
        }
        g.outputs
            .insert("f32_input".to_string(), "f32_input".to_string());

        let text = serialize_graph_to_wg_text(&g, SerializeOptions::default()).unwrap();
        assert!(text.contains("f32_input: f32[1];"));
        assert!(text.contains("f16_input: f16[1];"));
        assert!(text.contains("i32_input: i32[1];"));
        assert!(text.contains("u32_input: u32[1];"));
        assert!(text.contains("i64_input: i64[1];"));
        assert!(text.contains("u64_input: u64[1];"));
        assert!(text.contains("i8_input: i8[1];"));
        assert!(text.contains("u8_input: u8[1];"));
    }

    #[test]
    fn test_roundtrip() {
        let input = r#"
webnn_graph "resnet_head" v1 {
  inputs {
    x: f32[1, 2048];
  }
  consts {
    W: f32[2048, 1000] @weights("W");
    b: f32[1000] @weights("b");
  }
  nodes {
    logits0 = matmul(x, W);
    logits = add(logits0, b);
    probs = softmax(logits, axis=1);
  }
  outputs { probs; }
}
"#;
        // Parse the text
        let graph = parse_wg_text(input).unwrap();

        // Serialize back to text
        let serialized = serialize_graph_to_wg_text(&graph, SerializeOptions::default()).unwrap();

        // Parse again to verify structure is preserved
        let graph2 = parse_wg_text(&serialized).unwrap();

        // Header and declared names survive the round trip.
        assert_eq!(graph.name, graph2.name);
        assert_eq!(
            sorted_keys(graph.inputs.keys()),
            sorted_keys(graph2.inputs.keys())
        );
        assert_eq!(
            sorted_keys(graph.consts.keys()),
            sorted_keys(graph2.consts.keys())
        );
        assert_eq!(
            sorted_keys(graph.outputs.keys()),
            sorted_keys(graph2.outputs.keys())
        );
        for name in ["W", "b"] {
            assert_eq!(graph.consts[name].shape, graph2.consts[name].shape);
        }

        // Nodes keep their order, ids, ops, operands and options. A length
        // check alone would miss renamed or rewired nodes.
        assert_eq!(graph.nodes.len(), graph2.nodes.len());
        for (before, after) in graph.nodes.iter().zip(&graph2.nodes) {
            assert_eq!(before.id, after.id);
            assert_eq!(before.op, after.op);
            assert_eq!(before.inputs, after.inputs);
            assert_eq!(before.options, after.options);
            assert_eq!(before.outputs, after.outputs);
        }
    }

    #[test]
    fn test_default_graph_name() {
        let mut g = new_graph_json();
        // No name set (None)
        g.outputs.insert("x".to_string(), "x".to_string());

        let text = serialize_graph_to_wg_text(&g, SerializeOptions::default()).unwrap();
        assert!(text.contains(&format!("webnn_graph \"graph\" v{}", g.version)));
    }

    #[test]
    fn test_serialize_quantized_flag() {
        let mut g = new_graph_json();
        g.name = Some("q".to_string());
        g.quantized = false;
        g.outputs.insert("x".to_string(), "x".to_string());

        let plain = serialize_graph_to_wg_text(&g, SerializeOptions::default()).unwrap();
        assert!(!plain.contains("@quantized"));

        let forced = serialize_graph_to_wg_text(&g, SerializeOptions { quantized: true }).unwrap();
        assert!(forced.contains(&format!(
            "webnn_graph \"q\" v{} @quantized {{",
            g.version
        )));

        // The graph's own flag also triggers the annotation.
        g.quantized = true;
        let from_graph = serialize_graph_to_wg_text(&g, SerializeOptions::default()).unwrap();
        assert!(from_graph.contains(" @quantized {"));
    }

    #[test]
    fn test_serialize_empty_outputs() {
        let mut g = new_graph_json();
        g.outputs.clear();

        let text = serialize_graph_to_wg_text(&g, SerializeOptions::default()).unwrap();
        assert!(text.contains("  outputs {}\n"));
        assert!(text.ends_with("}\n"));
    }

    #[test]
    fn test_rejects_invalid_format_and_version() {
        let mut g = new_graph_json();
        g.format = "something-else".to_string();
        assert!(matches!(
            serialize_graph_to_wg_text(&g, SerializeOptions::default()),
            Err(SerializeError::InvalidFormat(ref found)) if found == "something-else"
        ));

        let mut g = new_graph_json();
        g.version = 99;
        assert!(matches!(
            serialize_graph_to_wg_text(&g, SerializeOptions::default()),
            Err(SerializeError::UnsupportedVersion(99))
        ));
    }

    #[test]
    fn test_string_escaping() {
        let mut g = new_graph_json();
        g.name = Some("test\"with\\quotes".to_string());
        g.outputs.insert("x".to_string(), "x".to_string());

        let text = serialize_graph_to_wg_text(&g, SerializeOptions::default()).unwrap();
        assert!(text.contains(&format!(
            "webnn_graph \"test\\\"with\\\\quotes\" v{}",
            g.version
        )));
    }

    #[test]
    fn test_value_types() {
        let mut g = new_graph_json();
        g.name = Some("test".to_string());
        g.inputs.insert(
            "x".to_string(),
            OperandDesc {
                data_type: DataType::Float32,
                shape: to_dimension_vector(&[1]),
            },
        );

        let mut options = serde_json::Map::new();
        options.insert("int_val".to_string(), serde_json::json!(42));
        options.insert("float_val".to_string(), serde_json::json!(3.5));
        options.insert("bool_val".to_string(), serde_json::json!(true));
        options.insert("null_val".to_string(), serde_json::json!(null));
        options.insert("array_val".to_string(), serde_json::json!([1, 2, 3]));

        g.nodes.push(Node {
            id: "result".to_string(),
            op: "test_op".to_string(),
            inputs: vec!["x".to_string()],
            options,
            outputs: None,
        });
        g.outputs.insert("result".to_string(), "result".to_string());

        let text = serialize_graph_to_wg_text(&g, SerializeOptions::default()).unwrap();
        assert!(text.contains("int_val=42"));
        assert!(text.contains("float_val=3.5"));
        assert!(text.contains("bool_val=true"));
        assert!(text.contains("null_val=null"));
        assert!(text.contains("array_val=[1, 2, 3]"));
    }
}