webnn_graph/
serialize.rs

1use crate::ast::{ConstInit, GraphJson, Node};
2use thiserror::Error;
3
4#[derive(Debug, Error)]
5pub enum SerializeError {
6    #[error("invalid format: {0}")]
7    InvalidFormat(String),
8    #[error("unsupported version: {0}")]
9    UnsupportedVersion(u32),
10}
11
12pub fn serialize_graph_to_wg_text(graph: &GraphJson) -> Result<String, SerializeError> {
13    let mut output = String::new();
14
15    // Validate format
16    if graph.format != "webnn-graph-json" {
17        return Err(SerializeError::InvalidFormat(graph.format.clone()));
18    }
19    if graph.version != 1 {
20        return Err(SerializeError::UnsupportedVersion(graph.version));
21    }
22
23    // Header
24    let name = graph.name.as_deref().unwrap_or("graph");
25    output.push_str(&format!("webnn_graph \"{}\" v1 {{\n", escape_string(name)));
26
27    // Inputs block
28    if !graph.inputs.is_empty() {
29        output.push_str("  inputs {\n");
30        for (name, desc) in &graph.inputs {
31            let dtype = desc.data_type.to_wg_text();
32            let shape = serialize_shape(&desc.shape);
33            output.push_str(&format!("    {}: {}{};\n", name, dtype, shape));
34        }
35        output.push_str("  }\n\n");
36    }
37
38    // Consts block
39    if !graph.consts.is_empty() {
40        output.push_str("  consts {\n");
41        for (name, const_decl) in &graph.consts {
42            let dtype = const_decl.data_type.to_wg_text();
43            let shape = serialize_shape(&const_decl.shape);
44            let annotation = serialize_const_init(&const_decl.init);
45            output.push_str(&format!("    {}: {}{}{}", name, dtype, shape, annotation));
46            output.push_str(";\n");
47        }
48        output.push_str("  }\n\n");
49    }
50
51    // Nodes block
52    if !graph.nodes.is_empty() {
53        output.push_str("  nodes {\n");
54        for node in &graph.nodes {
55            output.push_str(&format!("    {}\n", serialize_node(node)));
56        }
57        output.push_str("  }\n\n");
58    }
59
60    // Outputs block
61    output.push_str("  outputs {");
62    if !graph.outputs.is_empty() {
63        let outputs: Vec<String> = graph.outputs.keys().map(|k| format!(" {};", k)).collect();
64        output.push_str(&outputs.join(""));
65        output.push(' ');
66    }
67    output.push_str("}\n");
68
69    output.push_str("}\n");
70    Ok(output)
71}
72
/// Formats a shape as a bracketed, comma-separated list, e.g. `[1, 10]`.
/// An empty shape (a scalar) becomes `[]`.
fn serialize_shape(shape: &[u32]) -> String {
    let mut rendered = String::from("[");
    for (position, extent) in shape.iter().enumerate() {
        if position > 0 {
            rendered.push_str(", ");
        }
        rendered.push_str(&extent.to_string());
    }
    rendered.push(']');
    rendered
}
77
78fn serialize_const_init(init: &ConstInit) -> String {
79    match init {
80        ConstInit::Weights { r#ref } => {
81            format!(" @weights(\"{}\")", escape_string(r#ref))
82        }
83        ConstInit::Scalar { value } => {
84            format!(" @scalar({})", serialize_json_value(value))
85        }
86        ConstInit::InlineBytes { bytes } => {
87            // InlineBytes typically doesn't have a special annotation in the text format
88            // We'll serialize it without annotation (similar to default weights)
89            format!(" @bytes({:?})", bytes)
90        }
91    }
92}
93
94fn serialize_node(node: &Node) -> String {
95    let call = serialize_call(&node.op, &node.inputs, &node.options);
96
97    if let Some(outputs) = &node.outputs {
98        // Multi-output node: [a, b] = op(...)
99        let out_list = outputs.join(", ");
100        format!("[{}] = {};", out_list, call)
101    } else {
102        // Single output node: id = op(...)
103        format!("{} = {};", node.id, call)
104    }
105}
106
107fn serialize_call(
108    op: &str,
109    inputs: &[String],
110    options: &serde_json::Map<String, serde_json::Value>,
111) -> String {
112    let mut args = Vec::new();
113
114    // Add positional inputs
115    for input in inputs {
116        args.push(input.clone());
117    }
118
119    // Add named options
120    for (key, value) in options {
121        args.push(format!("{}={}", key, serialize_json_value(value)));
122    }
123
124    format!("{}({})", op, args.join(", "))
125}
126
127fn serialize_json_value(value: &serde_json::Value) -> String {
128    match value {
129        serde_json::Value::Null => "null".to_string(),
130        serde_json::Value::Bool(b) => b.to_string(),
131        serde_json::Value::Number(n) => n.to_string(),
132        serde_json::Value::String(s) => format!("\"{}\"", escape_string(s)),
133        serde_json::Value::Array(arr) => {
134            let items: Vec<String> = arr.iter().map(serialize_json_value).collect();
135            format!("[{}]", items.join(", "))
136        }
137        serde_json::Value::Object(_) => {
138            // Objects are not typically used in the WG text format
139            value.to_string()
140        }
141    }
142}
143
/// Escapes `s` for use inside a double-quoted string literal by prefixing
/// every backslash and every double quote with a backslash. All other
/// characters, including control characters, are copied unchanged.
fn escape_string(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for ch in s.chars() {
        match ch {
            '\\' => escaped.push_str("\\\\"),
            '"' => escaped.push_str("\\\""),
            other => escaped.push(other),
        }
    }
    escaped
}
147
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::{new_graph_json, ConstDecl, ConstInit, DataType, Node, OperandDesc};
    use crate::parser::parse_wg_text;

    /// Collects map keys into a sorted Vec so maps with arbitrary iteration
    /// order can be compared by their name sets.
    fn sorted_keys<'a>(keys: impl Iterator<Item = &'a String>) -> Vec<&'a String> {
        let mut collected: Vec<&String> = keys.collect();
        collected.sort();
        collected
    }

    #[test]
    fn test_serialize_simple_graph() {
        let mut g = new_graph_json();
        g.name = Some("test".to_string());
        g.inputs.insert(
            "x".to_string(),
            OperandDesc {
                data_type: DataType::Float32,
                shape: vec![1, 10],
            },
        );
        g.nodes.push(Node {
            id: "result".to_string(),
            op: "relu".to_string(),
            inputs: vec!["x".to_string()],
            options: serde_json::Map::new(),
            outputs: None,
        });
        g.outputs.insert("result".to_string(), "result".to_string());

        let text = serialize_graph_to_wg_text(&g).unwrap();
        assert!(text.contains("webnn_graph \"test\" v1"));
        assert!(text.contains("inputs {"));
        assert!(text.contains("x: f32[1, 10];"));
        assert!(text.contains("nodes {"));
        assert!(text.contains("result = relu(x);"));
        assert!(text.contains("outputs { result; }"));
    }

    #[test]
    fn test_serialize_weights_annotation() {
        let mut g = new_graph_json();
        g.name = Some("test".to_string());
        g.consts.insert(
            "W".to_string(),
            ConstDecl {
                data_type: DataType::Float32,
                shape: vec![10, 5],
                init: ConstInit::Weights {
                    r#ref: "W".to_string(),
                },
            },
        );
        g.outputs.insert("W".to_string(), "W".to_string());

        let text = serialize_graph_to_wg_text(&g).unwrap();
        assert!(text.contains("W: f32[10, 5] @weights(\"W\");"));
    }

    #[test]
    fn test_serialize_scalar_annotation() {
        let mut g = new_graph_json();
        g.name = Some("test".to_string());
        g.consts.insert(
            "scale".to_string(),
            ConstDecl {
                data_type: DataType::Float32,
                shape: vec![1],
                init: ConstInit::Scalar {
                    value: serde_json::json!(3.5),
                },
            },
        );
        g.outputs.insert("scale".to_string(), "scale".to_string());

        let text = serialize_graph_to_wg_text(&g).unwrap();
        assert!(text.contains("scale: f32[1] @scalar(3.5);"));
    }

    #[test]
    fn test_serialize_multi_output_node() {
        let mut g = new_graph_json();
        g.name = Some("test".to_string());
        g.inputs.insert(
            "x".to_string(),
            OperandDesc {
                data_type: DataType::Float32,
                shape: vec![10],
            },
        );
        g.nodes.push(Node {
            id: "a".to_string(),
            op: "split".to_string(),
            inputs: vec!["x".to_string()],
            options: serde_json::Map::new(),
            outputs: Some(vec!["a".to_string(), "b".to_string()]),
        });
        g.outputs.insert("a".to_string(), "a".to_string());

        let text = serialize_graph_to_wg_text(&g).unwrap();
        assert!(text.contains("[a, b] = split(x);"));
    }

    #[test]
    fn test_serialize_node_options() {
        let mut g = new_graph_json();
        g.name = Some("test".to_string());
        g.inputs.insert(
            "x".to_string(),
            OperandDesc {
                data_type: DataType::Float32,
                shape: vec![1, 10],
            },
        );

        let mut options = serde_json::Map::new();
        options.insert("axis".to_string(), serde_json::json!(1));
        options.insert("keepdims".to_string(), serde_json::json!(true));

        g.nodes.push(Node {
            id: "result".to_string(),
            op: "softmax".to_string(),
            inputs: vec!["x".to_string()],
            options,
            outputs: None,
        });
        g.outputs.insert("result".to_string(), "result".to_string());

        let text = serialize_graph_to_wg_text(&g).unwrap();
        assert!(text.contains("softmax(x,"));
        assert!(text.contains("axis=1"));
        assert!(text.contains("keepdims=true"));
    }

    #[test]
    fn test_serialize_various_dtypes() {
        let mut g = new_graph_json();
        g.name = Some("test".to_string());

        let dtypes = vec![
            ("f32_input", DataType::Float32),
            ("f16_input", DataType::Float16),
            ("i32_input", DataType::Int32),
            ("u32_input", DataType::Uint32),
            ("i64_input", DataType::Int64),
            ("u64_input", DataType::Uint64),
            ("i8_input", DataType::Int8),
            ("u8_input", DataType::Uint8),
        ];

        for (name, dtype) in dtypes {
            g.inputs.insert(
                name.to_string(),
                OperandDesc {
                    data_type: dtype,
                    shape: vec![1],
                },
            );
        }
        g.outputs
            .insert("f32_input".to_string(), "f32_input".to_string());

        let text = serialize_graph_to_wg_text(&g).unwrap();
        assert!(text.contains("f32_input: f32[1];"));
        assert!(text.contains("f16_input: f16[1];"));
        assert!(text.contains("i32_input: i32[1];"));
        assert!(text.contains("u32_input: u32[1];"));
        assert!(text.contains("i64_input: i64[1];"));
        assert!(text.contains("u64_input: u64[1];"));
        assert!(text.contains("i8_input: i8[1];"));
        assert!(text.contains("u8_input: u8[1];"));
    }

    #[test]
    fn test_roundtrip() {
        let input = r#"
webnn_graph "resnet_head" v1 {
  inputs {
    x: f32[1, 2048];
  }
  consts {
    W: f32[2048, 1000] @weights("W");
    b: f32[1000] @weights("b");
  }
  nodes {
    logits0 = matmul(x, W);
    logits = add(logits0, b);
    probs = softmax(logits, axis=1);
  }
  outputs { probs; }
}
"#;
        // Parse the text
        let graph = parse_wg_text(input).unwrap();

        // Serialize back to text
        let serialized = serialize_graph_to_wg_text(&graph).unwrap();

        // Parse again to verify structure is preserved
        let graph2 = parse_wg_text(&serialized).unwrap();

        // Verify key properties
        assert_eq!(graph.name, graph2.name);
        assert_eq!(graph.inputs.len(), graph2.inputs.len());
        assert_eq!(graph.consts.len(), graph2.consts.len());
        assert_eq!(graph.nodes.len(), graph2.nodes.len());
        assert_eq!(graph.outputs.len(), graph2.outputs.len());

        // Names must survive, not just counts (compared order-insensitively).
        assert_eq!(
            sorted_keys(graph.inputs.keys()),
            sorted_keys(graph2.inputs.keys())
        );
        assert_eq!(
            sorted_keys(graph.consts.keys()),
            sorted_keys(graph2.consts.keys())
        );
        assert_eq!(
            sorted_keys(graph.outputs.keys()),
            sorted_keys(graph2.outputs.keys())
        );

        // Nodes are an ordered list: every field must match position by position.
        for (before, after) in graph.nodes.iter().zip(&graph2.nodes) {
            assert_eq!(before.id, after.id);
            assert_eq!(before.op, after.op);
            assert_eq!(before.inputs, after.inputs);
            assert_eq!(before.options, after.options);
            assert_eq!(before.outputs, after.outputs);
        }
    }

    #[test]
    fn test_default_graph_name() {
        let mut g = new_graph_json();
        // No name set (None)
        g.outputs.insert("x".to_string(), "x".to_string());

        let text = serialize_graph_to_wg_text(&g).unwrap();
        assert!(text.contains("webnn_graph \"graph\" v1"));
    }

    #[test]
    fn test_empty_outputs_block_is_still_written() {
        let g = new_graph_json();

        let text = serialize_graph_to_wg_text(&g).unwrap();
        assert!(text.starts_with("webnn_graph \"graph\" v1 {\n"));
        assert!(text.ends_with("  outputs {}\n}\n"));
    }

    #[test]
    fn test_invalid_format_rejected() {
        let mut g = new_graph_json();
        g.format = "something-else".to_string();

        let result = serialize_graph_to_wg_text(&g);
        assert!(matches!(
            result,
            Err(SerializeError::InvalidFormat(ref found)) if found == "something-else"
        ));
    }

    #[test]
    fn test_unsupported_version_rejected() {
        let mut g = new_graph_json();
        g.version = 2;

        let result = serialize_graph_to_wg_text(&g);
        assert!(matches!(result, Err(SerializeError::UnsupportedVersion(2))));
    }

    #[test]
    fn test_string_escaping() {
        let mut g = new_graph_json();
        g.name = Some("test\"with\\quotes".to_string());
        g.outputs.insert("x".to_string(), "x".to_string());

        let text = serialize_graph_to_wg_text(&g).unwrap();
        assert!(text.contains("webnn_graph \"test\\\"with\\\\quotes\" v1"));
    }

    #[test]
    fn test_value_types() {
        let mut g = new_graph_json();
        g.name = Some("test".to_string());
        g.inputs.insert(
            "x".to_string(),
            OperandDesc {
                data_type: DataType::Float32,
                shape: vec![1],
            },
        );

        let mut options = serde_json::Map::new();
        options.insert("int_val".to_string(), serde_json::json!(42));
        options.insert("float_val".to_string(), serde_json::json!(3.5));
        options.insert("bool_val".to_string(), serde_json::json!(true));
        options.insert("null_val".to_string(), serde_json::json!(null));
        options.insert("array_val".to_string(), serde_json::json!([1, 2, 3]));

        g.nodes.push(Node {
            id: "result".to_string(),
            op: "test_op".to_string(),
            inputs: vec!["x".to_string()],
            options,
            outputs: None,
        });
        g.outputs.insert("result".to_string(), "result".to_string());

        let text = serialize_graph_to_wg_text(&g).unwrap();
        assert!(text.contains("int_val=42"));
        assert!(text.contains("float_val=3.5"));
        assert!(text.contains("bool_val=true"));
        assert!(text.contains("null_val=null"));
        assert!(text.contains("array_val=[1, 2, 3]"));
    }
}
408}