webnn_graph/
parser.rs

1use crate::ast::{new_graph_json, ConstDecl, ConstInit, DataType, GraphJson, Node, OperandDesc};
2use pest::iterators::Pair;
3use pest::Parser;
4use pest_derive::Parser;
5use serde_json::{Map, Value};
6use std::collections::BTreeMap;
7use thiserror::Error;
8
/// Parser type generated by `pest_derive` from the grammar file `wg.pest`.
///
/// The derive also generates the `Rule` enum that the rest of this module
/// matches on.
#[derive(Parser)]
#[grammar = "wg.pest"]
struct WGParser;
12
/// Errors reported while converting WG text into a graph description.
#[derive(Debug, Error)]
pub enum ParseError {
    /// The input did not match the pest grammar; carries the boxed pest error.
    #[error("parse error: {0}")]
    Pest(Box<pest::error::Error<Rule>>),
    /// A type annotation named a data type that `DataType::from_wg` rejected;
    /// carries the offending text.
    #[error("invalid dtype: {0}")]
    BadDType(String),
    /// The parse tree did not have the form this module expects (missing
    /// pair, unexpected rule, unparsable shape integer); carries a description.
    #[error("internal error: {0}")]
    Internal(String),
}
22
23impl From<pest::error::Error<Rule>> for ParseError {
24    fn from(err: pest::error::Error<Rule>) -> Self {
25        ParseError::Pest(Box::new(err))
26    }
27}
28
/// Result of parsing one expression, in the order `(op, inputs, options, outputs)`:
/// the operator name (empty for a bare identifier), the positional operand
/// names, the named arguments, and an optional list of output names.
type ParsedExpr = (String, Vec<String>, Map<String, Value>, Option<Vec<String>>);
30
31pub fn parse_wg_text(input: &str) -> Result<GraphJson, ParseError> {
32    let mut pairs = WGParser::parse(Rule::file, input)?;
33    let file = pairs
34        .next()
35        .ok_or_else(|| ParseError::Internal("missing file".into()))?;
36
37    let mut g = new_graph_json();
38    let mut nodes: Vec<Node> = Vec::new();
39
40    for p in file.into_inner() {
41        match p.as_rule() {
42            Rule::header => {
43                // Extract graph name from header
44                for inner in p.into_inner() {
45                    if inner.as_rule() == Rule::string {
46                        g.name = Some(unquote(inner.as_str()));
47                        break;
48                    }
49                }
50            }
51            Rule::inputs_block => parse_inputs_block(p, &mut g.inputs)?,
52            Rule::consts_block => parse_consts_block(p, &mut g.consts)?,
53            Rule::nodes_block => parse_nodes_block(p, &mut nodes)?,
54            Rule::outputs_block => parse_outputs_block(p, &mut g.outputs)?,
55            _ => {}
56        }
57    }
58
59    g.nodes = nodes;
60    Ok(g)
61}
62
63fn parse_inputs_block(
64    p: Pair<Rule>,
65    out: &mut BTreeMap<String, OperandDesc>,
66) -> Result<(), ParseError> {
67    for inner in p.into_inner() {
68        if inner.as_rule() == Rule::input_decl {
69            let mut it = inner.into_inner();
70            let name = it.next().unwrap().as_str().to_string();
71            let (dt, shape) = parse_ty(it.next().unwrap())?;
72            out.insert(
73                name,
74                OperandDesc {
75                    data_type: dt,
76                    shape,
77                },
78            );
79        }
80    }
81    Ok(())
82}
83
84fn parse_consts_block(
85    p: Pair<Rule>,
86    out: &mut BTreeMap<String, ConstDecl>,
87) -> Result<(), ParseError> {
88    for inner in p.into_inner() {
89        if inner.as_rule() == Rule::const_decl {
90            let mut it = inner.into_inner();
91            let name = it.next().unwrap().as_str().to_string();
92            let (dt, shape) = parse_ty(it.next().unwrap())?;
93
94            let mut init: Option<ConstInit> = None;
95            for ann in it {
96                if ann.as_rule() == Rule::const_annot {
97                    let text = ann.as_str();
98                    if text.starts_with("@weights") {
99                        let s = ann
100                            .into_inner()
101                            .find(|p| p.as_rule() == Rule::string)
102                            .map(|p| unquote(p.as_str()))
103                            .unwrap_or_else(|| name.clone());
104                        init = Some(ConstInit::Weights { r#ref: s });
105                    } else if text.starts_with("@scalar") {
106                        let n = ann
107                            .into_inner()
108                            .find(|p| p.as_rule() == Rule::number)
109                            .map(|p| parse_number_value(p.as_str()))
110                            .unwrap_or(Value::Null);
111                        init = Some(ConstInit::Scalar { value: n });
112                    }
113                }
114            }
115
116            let init = init.unwrap_or(ConstInit::Weights {
117                r#ref: name.clone(),
118            });
119            out.insert(
120                name,
121                ConstDecl {
122                    data_type: dt,
123                    shape,
124                    init,
125                },
126            );
127        }
128    }
129    Ok(())
130}
131
132fn parse_nodes_block(p: Pair<Rule>, out: &mut Vec<Node>) -> Result<(), ParseError> {
133    for inner in p.into_inner() {
134        if inner.as_rule() != Rule::stmt {
135            continue;
136        }
137        let stmt = inner.into_inner().next().unwrap();
138        match stmt.as_rule() {
139            Rule::assign => out.push(parse_assign(stmt)?),
140            Rule::multi_assign => out.push(parse_multi_assign(stmt)?),
141            _ => {}
142        }
143    }
144    Ok(())
145}
146
147fn parse_assign(p: Pair<Rule>) -> Result<Node, ParseError> {
148    let mut it = p.into_inner();
149    let id = it.next().unwrap().as_str().to_string();
150    let (op, inputs, options, outputs) = parse_expr(it.next().unwrap())?;
151    Ok(Node {
152        id,
153        op,
154        inputs,
155        options,
156        outputs,
157    })
158}
159
160fn parse_multi_assign(p: Pair<Rule>) -> Result<Node, ParseError> {
161    let mut it = p.into_inner();
162    let mut outs: Vec<String> = Vec::new();
163
164    // first items are idents inside [...]
165    // We receive them as a flat sequence of ident tokens due to grammar.
166    // Collect until we hit expr.
167    while let Some(next) = it.peek() {
168        if next.as_rule() == Rule::expr {
169            break;
170        }
171        let t = it.next().unwrap();
172        if t.as_rule() == Rule::ident {
173            outs.push(t.as_str().to_string());
174        }
175    }
176
177    let expr = it
178        .next()
179        .ok_or_else(|| ParseError::Internal("missing expr in multi_assign".into()))?;
180    let (op, inputs, options, _outputs_unused) = parse_expr(expr)?;
181    // Use the first output name as the node id for uniqueness; keep real outputs in Node.outputs.
182    let id = outs.first().cloned().unwrap_or_else(|| "tmp".into());
183    Ok(Node {
184        id,
185        op,
186        inputs,
187        options,
188        outputs: Some(outs),
189    })
190}
191
192fn parse_expr(p: Pair<Rule>) -> Result<ParsedExpr, ParseError> {
193    match p.as_rule() {
194        Rule::expr => parse_expr(p.into_inner().next().unwrap()),
195        Rule::call => parse_call(p),
196        Rule::ident => Ok((
197            String::new(),
198            vec![p.as_str().to_string()],
199            Map::new(),
200            None,
201        )),
202        _ => Err(ParseError::Internal(format!(
203            "unexpected expr rule: {:?}",
204            p.as_rule()
205        ))),
206    }
207}
208
209fn parse_call(p: Pair<Rule>) -> Result<ParsedExpr, ParseError> {
210    let mut it = p.into_inner();
211    let op = it.next().unwrap().as_str().to_string();
212    let mut inputs: Vec<String> = Vec::new();
213    let mut options: Map<String, Value> = Map::new();
214
215    if let Some(args) = it.next() {
216        if args.as_rule() == Rule::args {
217            for arg in args.into_inner() {
218                if arg.as_rule() != Rule::arg {
219                    continue;
220                }
221                let mut a = arg.into_inner().peekable();
222
223                // Check if this is a named argument: ident '=' value
224                let first = match a.next() {
225                    Some(f) => f,
226                    None => continue,
227                };
228
229                if first.as_rule() == Rule::ident
230                    && a.peek().is_some()
231                    && a.peek().unwrap().as_rule() == Rule::value
232                {
233                    // Named argument
234                    let key = first.as_str().to_string();
235                    let val = parse_value(a.next().unwrap())?;
236                    options.insert(key, val);
237                } else {
238                    // Positional argument
239                    match first.as_rule() {
240                        Rule::value => {
241                            let v = parse_value(first)?;
242                            if let Value::String(s) = v {
243                                inputs.push(s);
244                            } else if let Some(sym) = v.as_str() {
245                                inputs.push(sym.to_string());
246                            }
247                        }
248                        Rule::ident => inputs.push(first.as_str().to_string()),
249                        _ => {}
250                    }
251                }
252            }
253        }
254    }
255
256    Ok((op, inputs, options, None))
257}
258
259fn parse_outputs_block(
260    p: Pair<Rule>,
261    out: &mut BTreeMap<String, String>,
262) -> Result<(), ParseError> {
263    // WG: outputs { probs }  OR outputs { a,b; }
264    // We'll map each output name to itself.
265    for inner in p.into_inner() {
266        if inner.as_rule() == Rule::output_item {
267            for item in inner.into_inner() {
268                if item.as_rule() == Rule::ident {
269                    let name = item.as_str().to_string();
270                    out.insert(name.clone(), name);
271                }
272            }
273        }
274    }
275    Ok(())
276}
277
278fn parse_ty(p: Pair<Rule>) -> Result<(DataType, Vec<u32>), ParseError> {
279    let mut it = p.into_inner();
280    let dt_s = it.next().unwrap().as_str();
281    let dt = DataType::from_wg(dt_s).ok_or_else(|| ParseError::BadDType(dt_s.to_string()))?;
282    let shape = parse_shape(it.next().unwrap())?;
283    Ok((dt, shape))
284}
285
286fn parse_shape(p: Pair<Rule>) -> Result<Vec<u32>, ParseError> {
287    let mut shape = Vec::new();
288    for inner in p.into_inner() {
289        if inner.as_rule() == Rule::int {
290            let v: u32 = inner
291                .as_str()
292                .parse()
293                .map_err(|_| ParseError::Internal("bad int".into()))?;
294            shape.push(v);
295        }
296    }
297    Ok(shape)
298}
299
300fn parse_value(p: Pair<Rule>) -> Result<Value, ParseError> {
301    match p.as_rule() {
302        Rule::value => parse_value(p.into_inner().next().unwrap()),
303        Rule::literal => parse_value(p.into_inner().next().unwrap()),
304        Rule::string => Ok(Value::String(unquote(p.as_str()))),
305        Rule::number => Ok(parse_number_value(p.as_str())),
306        Rule::boolean => Ok(Value::Bool(p.as_str() == "true")),
307        Rule::null => Ok(Value::Null),
308        Rule::array => {
309            let mut arr = Vec::new();
310            for inner in p.into_inner() {
311                if inner.as_rule() == Rule::value {
312                    arr.push(parse_value(inner)?);
313                }
314            }
315            Ok(Value::Array(arr))
316        }
317        Rule::ident => Ok(Value::String(p.as_str().to_string())),
318        _ => Err(ParseError::Internal(format!(
319            "unexpected value rule: {:?}",
320            p.as_rule()
321        ))),
322    }
323}
324
325fn parse_number_value(s: &str) -> Value {
326    // Prefer i64 when exact, otherwise f64.
327    if !s.contains('.') && !s.contains('e') && !s.contains('E') {
328        if let Ok(i) = s.parse::<i64>() {
329            return Value::Number(i.into());
330        }
331    }
332    Value::Number(serde_json::Number::from_f64(s.parse::<f64>().unwrap_or(0.0)).unwrap())
333}
334
/// Strips one pair of surrounding double quotes and resolves `\"` and `\\`.
///
/// The quotes are removed only when the text both starts and ends with `"`
/// and is at least two bytes long; otherwise the text is used as is. Escaped
/// quotes are replaced first, then escaped backslashes.
fn unquote(s: &str) -> String {
    let inner = s
        .strip_prefix('"')
        .and_then(|rest| rest.strip_suffix('"'))
        .unwrap_or(s);

    let quotes_restored = inner.replace("\\\"", "\"");
    quotes_restored.replace("\\\\", "\\")
}
343
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `src` and panics if the parser reports an error.
    fn graph_of(src: &str) -> GraphJson {
        parse_wg_text(src).unwrap()
    }

    /// A minimal graph with one section of each kind parses into one entry each.
    #[test]
    fn test_parse_simple_graph() {
        let input = r#"
webnn_graph "test" v1 {
  inputs {
    x: f32[1, 10];
  }
  consts {
    W: f32[10, 5] @weights("W");
  }
  nodes {
    result = matmul(x, W);
  }
  outputs { result; }
}
"#;
        let parsed = graph_of(input);
        assert_eq!(parsed.format, "webnn-graph-json");
        assert_eq!(parsed.version, 1);
        assert_eq!(parsed.inputs.len(), 1);
        assert_eq!(parsed.consts.len(), 1);
        assert_eq!(parsed.nodes.len(), 1);
        assert_eq!(parsed.outputs.len(), 1);
    }

    /// Input declarations keep their data type and shape.
    #[test]
    fn test_parse_inputs() {
        let input = r#"
webnn_graph "test" v1 {
  inputs {
    x: f32[1, 10];
    y: i32[5];
  }
  nodes {}
  outputs { x; }
}
"#;
        let parsed = graph_of(input);
        assert_eq!(parsed.inputs.len(), 2);
        for key in ["x", "y"] {
            assert!(parsed.inputs.contains_key(key));
        }

        let x = &parsed.inputs["x"];
        assert_eq!(x.data_type, DataType::Float32);
        assert_eq!(x.shape, vec![1, 10]);

        let y = &parsed.inputs["y"];
        assert_eq!(y.data_type, DataType::Int32);
        assert_eq!(y.shape, vec![5]);
    }

    /// `@weights("...")` stores the quoted reference, not the constant's name.
    #[test]
    fn test_parse_consts_with_weights() {
        let input = r#"
webnn_graph "test" v1 {
  inputs { x: f32[1]; }
  consts {
    W: f32[10, 5] @weights("W");
    b: f32[5] @weights("bias");
  }
  nodes {}
  outputs { x; }
}
"#;
        let parsed = graph_of(input);
        assert_eq!(parsed.consts.len(), 2);

        let weight = &parsed.consts["W"];
        assert_eq!(weight.data_type, DataType::Float32);
        assert_eq!(weight.shape, vec![10, 5]);
        assert!(matches!(&weight.init, ConstInit::Weights { r#ref } if r#ref == "W"));

        let bias = &parsed.consts["b"];
        assert!(matches!(&bias.init, ConstInit::Weights { r#ref } if r#ref == "bias"));
    }

    /// `@scalar(n)` stores the number as the constant's initial value.
    #[test]
    fn test_parse_consts_with_scalar() {
        let input = r#"
webnn_graph "test" v1 {
  inputs { x: f32[1]; }
  consts {
    scale: f32[1] @scalar(2.5);
  }
  nodes {}
  outputs { x; }
}
"#;
        let parsed = graph_of(input);
        let init = &parsed.consts["scale"].init;
        if let ConstInit::Scalar { value } = init {
            assert_eq!(value.as_f64().unwrap(), 2.5);
        } else {
            panic!("Expected scalar init");
        }
    }

    /// A plain call becomes a node with positional inputs and no options.
    #[test]
    fn test_parse_nodes() {
        let input = r#"
webnn_graph "test" v1 {
  inputs { x: f32[1, 2048]; }
  consts { W: f32[2048, 1000] @weights("W"); }
  nodes {
    result = matmul(x, W);
  }
  outputs { result; }
}
"#;
        let parsed = graph_of(input);
        assert_eq!(parsed.nodes.len(), 1);

        let matmul = &parsed.nodes[0];
        assert_eq!(matmul.id, "result");
        assert_eq!(matmul.op, "matmul");
        assert_eq!(matmul.inputs, vec!["x", "W"]);
        assert!(matmul.options.is_empty());
    }

    /// `key=value` arguments land in the options map, not in the inputs.
    #[test]
    fn test_parse_nodes_with_options() {
        let input = r#"
webnn_graph "test" v1 {
  inputs { x: f32[1, 10]; }
  nodes {
    result = softmax(x, axis=1);
  }
  outputs { result; }
}
"#;
        let parsed = graph_of(input);
        let softmax = &parsed.nodes[0];
        assert_eq!(softmax.op, "softmax");
        assert_eq!(softmax.inputs, vec!["x"]);
        let axis = softmax.options.get("axis").unwrap();
        assert_eq!(axis.as_i64().unwrap(), 1);
    }

    /// `[a, b] = ...` uses the first target as id and keeps all targets as outputs.
    #[test]
    fn test_parse_multi_assign() {
        let input = r#"
webnn_graph "test" v1 {
  inputs { x: f32[10]; }
  nodes {
    [a, b] = split(x);
  }
  outputs { a; }
}
"#;
        let parsed = graph_of(input);
        let split = &parsed.nodes[0];
        assert_eq!(split.id, "a");
        assert_eq!(split.op, "split");
        let expected = vec!["a".to_string(), "b".to_string()];
        assert_eq!(split.outputs, Some(expected));
    }

    /// Every listed output maps to itself.
    #[test]
    fn test_parse_outputs() {
        let input = r#"
webnn_graph "test" v1 {
  inputs { x: f32[1]; }
  nodes {
    a = relu(x);
    b = sigmoid(x);
  }
  outputs { a; b; }
}
"#;
        let parsed = graph_of(input);
        assert_eq!(parsed.outputs.len(), 2);
        for name in ["a", "b"] {
            assert_eq!(parsed.outputs.get(name).unwrap(), name);
        }
    }

    /// An unknown dtype spelling is rejected by the grammar itself.
    #[test]
    fn test_parse_invalid_dtype() {
        let input = r#"
webnn_graph "test" v1 {
  inputs { x: float32[1]; }
  nodes {}
  outputs { x; }
}
"#;
        let outcome = parse_wg_text(input);
        assert!(outcome.is_err());
        // The pest parser should fail because "float32" doesn't match the dtype rule
        match outcome {
            Ok(_) => panic!("Expected error but parsing succeeded"),
            Err(ParseError::Pest(_)) => {}
            Err(e) => panic!("Expected Pest parse error, got: {:?}", e),
        }
    }

    /// Quotes are stripped and the two supported escapes are resolved.
    #[test]
    fn test_unquote() {
        let cases = [
            (r#""hello""#, "hello"),
            (r#""hello\"world""#, "hello\"world"),
            (r#""path\\to\\file""#, "path\\to\\file"),
            ("no_quotes", "no_quotes"),
        ];
        for (raw, expected) in cases {
            assert_eq!(unquote(raw), expected);
        }
    }

    /// Integers stay integers; decimal and scientific notation become floats.
    #[test]
    fn test_parse_number_value() {
        let integer = parse_number_value("42");
        assert_eq!(integer.as_i64().unwrap(), 42);

        let decimal = parse_number_value("3.12");
        assert_eq!(decimal.as_f64().unwrap(), 3.12);

        let scientific = parse_number_value("1e-3");
        assert_eq!(scientific.as_f64().unwrap(), 0.001);
    }
}