1use crate::ast::{
2 new_graph_json, ConstDecl, ConstInit, DataType, Dimension, DynamicDimension, GraphJson, Node,
3 OperandDesc,
4};
5use pest::iterators::Pair;
6use pest::Parser;
7use pest_derive::Parser;
8use serde_json::{Map, Value};
9use std::collections::BTreeMap;
10use thiserror::Error;
11
/// Parser type generated by `pest_derive` from the grammar file `wg.pest`.
///
/// The derive also supplies the `Rule` enum that the rest of this file matches on.
#[derive(Parser)]
#[grammar = "wg.pest"]
struct WGParser;
15
/// Errors that can be returned while converting WG text into a graph.
#[derive(Debug, Error)]
pub enum ParseError {
    /// The input was rejected by the pest grammar; the pest error is boxed.
    #[error("parse error: {0}")]
    Pest(Box<pest::error::Error<Rule>>),
    /// A data-type token was not recognised by `DataType::from_wg`.
    #[error("invalid dtype: {0}")]
    BadDType(String),
    /// The parse tree had an unexpected shape or held a value that could not be converted.
    #[error("internal error: {0}")]
    Internal(String),
    /// A constant was declared with a dynamic dimension.
    #[error("constant shapes must be static")]
    DynamicConstShape,
}
27
28impl From<pest::error::Error<Rule>> for ParseError {
29 fn from(err: pest::error::Error<Rule>) -> Self {
30 ParseError::Pest(Box::new(err))
31 }
32}
33
/// Result of parsing one expression, in order: operator name, positional input names,
/// named options, and optional output names (the fields later copied into a `Node`).
type ParsedExpr = (String, Vec<String>, Map<String, Value>, Option<Vec<String>>);
35
36pub fn parse_wg_text(input: &str) -> Result<GraphJson, ParseError> {
37 let mut pairs = WGParser::parse(Rule::file, input)?;
38 let file = pairs
39 .next()
40 .ok_or_else(|| ParseError::Internal("missing file".into()))?;
41
42 let mut g = new_graph_json();
43 let mut nodes: Vec<Node> = Vec::new();
44
45 for p in file.into_inner() {
46 match p.as_rule() {
47 Rule::header => {
48 for inner in p.into_inner() {
50 match inner.as_rule() {
51 Rule::string => g.name = Some(unquote(inner.as_str())),
52 Rule::int => {
53 let version: u32 = inner
54 .as_str()
55 .parse()
56 .map_err(|e| ParseError::Internal(format!("bad version: {}", e)))?;
57 g.version = version;
58 }
59 Rule::quantized => g.quantized = true,
60 _ => {}
61 }
62 }
63 }
64 Rule::inputs_block => parse_inputs_block(p, &mut g.inputs)?,
65 Rule::consts_block => parse_consts_block(p, &mut g.consts)?,
66 Rule::nodes_block => parse_nodes_block(p, &mut nodes)?,
67 Rule::outputs_block => parse_outputs_block(p, &mut g.outputs)?,
68 _ => {}
69 }
70 }
71
72 g.nodes = nodes;
73 Ok(g)
74}
75
76fn parse_inputs_block(
77 p: Pair<Rule>,
78 out: &mut BTreeMap<String, OperandDesc>,
79) -> Result<(), ParseError> {
80 for inner in p.into_inner() {
81 if inner.as_rule() == Rule::input_decl {
82 let mut it = inner.into_inner();
83 let name = it.next().unwrap().as_str().to_string();
84 let (dt, shape) = parse_ty(it.next().unwrap())?;
85 out.insert(
86 name,
87 OperandDesc {
88 data_type: dt,
89 shape,
90 },
91 );
92 }
93 }
94 Ok(())
95}
96
97fn parse_consts_block(
98 p: Pair<Rule>,
99 out: &mut BTreeMap<String, ConstDecl>,
100) -> Result<(), ParseError> {
101 for inner in p.into_inner() {
102 if inner.as_rule() == Rule::const_decl {
103 let mut it = inner.into_inner();
104 let name = it.next().unwrap().as_str().to_string();
105 let (dt, shape) = parse_ty(it.next().unwrap())?;
106
107 let mut init: Option<ConstInit> = None;
108 for ann in it {
109 if ann.as_rule() == Rule::const_annot {
110 let text = ann.as_str();
111 if text.starts_with("@weights") {
112 let s = ann
113 .into_inner()
114 .find(|p| p.as_rule() == Rule::string)
115 .map(|p| unquote(p.as_str()))
116 .unwrap_or_else(|| name.clone());
117 init = Some(ConstInit::Weights { r#ref: s });
118 } else if text.starts_with("@scalar") {
119 let n = ann
120 .into_inner()
121 .find(|p| p.as_rule() == Rule::number)
122 .map(|p| parse_number_value(p.as_str()))
123 .unwrap_or(Value::Null);
124 init = Some(ConstInit::Scalar { value: n });
125 } else if text.starts_with("@bytes") {
126 let bytes = ann
127 .into_inner()
128 .find(|p| p.as_rule() == Rule::byte_array)
129 .map(|pair| {
130 pair.into_inner()
131 .filter(|p| p.as_rule() == Rule::int)
132 .filter_map(|p| p.as_str().parse::<u32>().ok())
133 .map(|v| v as u8)
134 .collect::<Vec<u8>>()
135 })
136 .unwrap_or_default();
137 init = Some(ConstInit::InlineBytes { bytes });
138 }
139 }
140 }
141
142 let init = init.unwrap_or(ConstInit::Weights {
143 r#ref: name.clone(),
144 });
145 out.insert(
146 name,
147 ConstDecl {
148 data_type: dt,
149 shape: dims_to_static_shape(&shape)?,
150 init,
151 },
152 );
153 }
154 }
155 Ok(())
156}
157
158fn parse_nodes_block(p: Pair<Rule>, out: &mut Vec<Node>) -> Result<(), ParseError> {
159 for inner in p.into_inner() {
160 if inner.as_rule() != Rule::stmt {
161 continue;
162 }
163 let stmt = inner.into_inner().next().unwrap();
164 match stmt.as_rule() {
165 Rule::assign => out.push(parse_assign(stmt)?),
166 Rule::multi_assign => out.push(parse_multi_assign(stmt)?),
167 _ => {}
168 }
169 }
170 Ok(())
171}
172
173fn parse_assign(p: Pair<Rule>) -> Result<Node, ParseError> {
174 let mut it = p.into_inner();
175 let id = it.next().unwrap().as_str().to_string();
176 let (op, inputs, options, outputs) = parse_expr(it.next().unwrap())?;
177 Ok(Node {
178 id,
179 op,
180 inputs,
181 options,
182 outputs,
183 })
184}
185
186fn parse_multi_assign(p: Pair<Rule>) -> Result<Node, ParseError> {
187 let mut it = p.into_inner();
188 let mut outs: Vec<String> = Vec::new();
189
190 while let Some(next) = it.peek() {
194 if next.as_rule() == Rule::expr {
195 break;
196 }
197 let t = it.next().unwrap();
198 if t.as_rule() == Rule::ident {
199 outs.push(t.as_str().to_string());
200 }
201 }
202
203 let expr = it
204 .next()
205 .ok_or_else(|| ParseError::Internal("missing expr in multi_assign".into()))?;
206 let (op, inputs, options, _outputs_unused) = parse_expr(expr)?;
207 let id = outs.first().cloned().unwrap_or_else(|| "tmp".into());
209 Ok(Node {
210 id,
211 op,
212 inputs,
213 options,
214 outputs: Some(outs),
215 })
216}
217
218fn parse_expr(p: Pair<Rule>) -> Result<ParsedExpr, ParseError> {
219 match p.as_rule() {
220 Rule::expr => parse_expr(p.into_inner().next().unwrap()),
221 Rule::call => parse_call(p),
222 Rule::ident => Ok((
223 String::new(),
224 vec![p.as_str().to_string()],
225 Map::new(),
226 None,
227 )),
228 _ => Err(ParseError::Internal(format!(
229 "unexpected expr rule: {:?}",
230 p.as_rule()
231 ))),
232 }
233}
234
235fn parse_call(p: Pair<Rule>) -> Result<ParsedExpr, ParseError> {
236 let mut it = p.into_inner();
237 let op = it.next().unwrap().as_str().to_string();
238 let mut inputs: Vec<String> = Vec::new();
239 let mut options: Map<String, Value> = Map::new();
240
241 let is_concat = op == "concat";
243 if is_concat {
244 crate::debug_println!("[PARSER DEBUG] Parsing concat operation");
245 }
246
247 if let Some(args) = it.next() {
248 if args.as_rule() == Rule::args {
249 for (arg_idx, arg) in args.into_inner().enumerate() {
250 if arg.as_rule() != Rule::arg {
251 continue;
252 }
253 let mut a = arg.into_inner().peekable();
254
255 let first = match a.next() {
257 Some(f) => f,
258 None => continue,
259 };
260
261 if is_concat {
262 crate::debug_println!(
263 "[PARSER DEBUG] arg[{}]: first.rule={:?}, first.as_str()={}, has_next={}",
264 arg_idx,
265 first.as_rule(),
266 first.as_str(),
267 a.peek().is_some()
268 );
269 if let Some(next) = a.peek() {
270 crate::debug_println!(
271 "[PARSER DEBUG] arg[{}]: next.rule={:?}, next.as_str()={}",
272 arg_idx,
273 next.as_rule(),
274 next.as_str()
275 );
276 }
277 }
278
279 if first.as_rule() == Rule::ident
280 && a.peek().is_some()
281 && a.peek().unwrap().as_rule() == Rule::value
282 {
283 let key = first.as_str().to_string();
285 let val = parse_value(a.next().unwrap())?;
286 if is_concat {
287 crate::debug_println!("[PARSER DEBUG] Named argument: {}={:?}", key, val);
288 }
289 options.insert(key, val);
290 } else {
291 if is_concat {
293 crate::debug_println!(
294 "[PARSER DEBUG] Positional argument: rule={:?}",
295 first.as_rule()
296 );
297 }
298 let v = parse_value(first)?;
301 match v {
302 Value::String(s) => inputs.push(s),
303 Value::Array(arr) => {
304 for item in arr {
305 match item {
306 Value::String(s) => inputs.push(s),
307 other => {
308 if let Some(s) = other.as_str() {
309 inputs.push(s.to_string());
310 }
311 }
312 }
313 }
314 }
315 other => {
316 if let Some(s) = other.as_str() {
317 inputs.push(s.to_string());
318 }
319 }
320 }
321 }
322 }
323 }
324 }
325
326 if is_concat {
327 crate::debug_println!(
328 "[PARSER DEBUG] Concat parsed: inputs={:?}, options={:?}",
329 inputs,
330 options
331 );
332 }
333
334 Ok((op, inputs, options, None))
335}
336
337fn parse_outputs_block(
338 p: Pair<Rule>,
339 out: &mut BTreeMap<String, String>,
340) -> Result<(), ParseError> {
341 for inner in p.into_inner() {
344 if inner.as_rule() == Rule::output_item {
345 for item in inner.into_inner() {
346 if item.as_rule() == Rule::ident {
347 let name = item.as_str().to_string();
348 out.insert(name.clone(), name);
349 }
350 }
351 }
352 }
353 Ok(())
354}
355
356fn parse_ty(p: Pair<Rule>) -> Result<(DataType, Vec<Dimension>), ParseError> {
357 let mut it = p.into_inner();
358 let dt_s = it.next().unwrap().as_str();
359 let dt = DataType::from_wg(dt_s).ok_or_else(|| ParseError::BadDType(dt_s.to_string()))?;
360 let shape = parse_shape(it.next().unwrap())?;
361 Ok((dt, shape))
362}
363
364fn parse_shape(p: Pair<Rule>) -> Result<Vec<Dimension>, ParseError> {
365 let mut shape = Vec::new();
366 for inner in p.into_inner() {
367 if inner.as_rule() == Rule::shape_dim {
368 let item = inner
369 .into_inner()
370 .next()
371 .ok_or_else(|| ParseError::Internal("shape_dim missing inner value".to_string()))?;
372 match item.as_rule() {
373 Rule::int => {
374 let v: u32 = item
375 .as_str()
376 .parse()
377 .map_err(|_| ParseError::Internal("bad int".into()))?;
378 shape.push(Dimension::Static(v));
379 }
380 Rule::dynamic_dim => {
381 let mut it = item.into_inner();
382 let name = it
383 .next()
384 .map(|p| unquote(p.as_str()))
385 .ok_or_else(|| ParseError::Internal("dynamic_dim missing name".into()))?;
386 let max_size: u32 = it
387 .next()
388 .ok_or_else(|| ParseError::Internal("dynamic_dim missing max".into()))?
389 .as_str()
390 .parse()
391 .map_err(|_| ParseError::Internal("dynamic_dim bad max".into()))?;
392 shape.push(Dimension::Dynamic(DynamicDimension { name, max_size }));
393 }
394 _ => return Err(ParseError::Internal("unexpected shape_dim rule".into())),
395 }
396 }
397 }
398 Ok(shape)
399}
400
401fn dims_to_static_shape(shape: &[Dimension]) -> Result<Vec<u32>, ParseError> {
402 let mut out = Vec::with_capacity(shape.len());
403 for dim in shape {
404 match dim {
405 Dimension::Static(v) => out.push(*v),
406 Dimension::Dynamic(_) => return Err(ParseError::DynamicConstShape),
407 }
408 }
409 Ok(out)
410}
411
412fn parse_value(p: Pair<Rule>) -> Result<Value, ParseError> {
413 match p.as_rule() {
414 Rule::value => parse_value(p.into_inner().next().unwrap()),
415 Rule::literal => parse_value(p.into_inner().next().unwrap()),
416 Rule::string => Ok(Value::String(unquote(p.as_str()))),
417 Rule::number => Ok(parse_number_value(p.as_str())),
418 Rule::boolean => Ok(Value::Bool(p.as_str() == "true")),
419 Rule::null => Ok(Value::Null),
420 Rule::array => {
421 let mut arr = Vec::new();
422 for inner in p.into_inner() {
423 if inner.as_rule() == Rule::value {
424 arr.push(parse_value(inner)?);
425 }
426 }
427 Ok(Value::Array(arr))
428 }
429 Rule::object => {
430 let mut map = serde_json::Map::new();
431 for inner in p.into_inner() {
432 if inner.as_rule() == Rule::object_item {
433 let mut it = inner.into_inner();
434 let key_pair = it
435 .next()
436 .ok_or_else(|| ParseError::Internal("object key missing".into()))?;
437 let key = match key_pair.as_rule() {
438 Rule::string => unquote(key_pair.as_str()),
439 Rule::ident => key_pair.as_str().to_string(),
440 _ => {
441 return Err(ParseError::Internal(
442 "unexpected object key rule".to_string(),
443 ));
444 }
445 };
446 let value_pair = it
447 .next()
448 .ok_or_else(|| ParseError::Internal("object value missing".into()))?;
449 map.insert(key, parse_value(value_pair)?);
450 }
451 }
452 Ok(Value::Object(map))
453 }
454 Rule::ident => Ok(Value::String(p.as_str().to_string())),
455 _ => Err(ParseError::Internal(format!(
456 "unexpected value rule: {:?}",
457 p.as_rule()
458 ))),
459 }
460}
461
462fn parse_number_value(s: &str) -> Value {
463 if !s.contains('.') && !s.contains('e') && !s.contains('E') {
465 if let Ok(i) = s.parse::<i64>() {
466 return Value::Number(i.into());
467 }
468 }
469 Value::Number(serde_json::Number::from_f64(s.parse::<f64>().unwrap_or(0.0)).unwrap())
470}
471
/// Strips one pair of surrounding double quotes, if both are present, and then resolves
/// the escapes `\"` and `\\`.
///
/// Text that is not wrapped in a matching pair of quotes keeps its outer characters, but
/// the escape replacement is still applied to it.
fn unquote(s: &str) -> String {
    let inner = s
        .strip_prefix('"')
        .and_then(|rest| rest.strip_suffix('"'))
        .unwrap_or(s);

    // Order matters: escaped quotes are resolved before escaped backslashes.
    let quotes_resolved = inner.replace("\\\"", "\"");
    quotes_resolved.replace("\\\\", "\\")
}
480
// Unit tests for the WG text parser: each test feeds a small graph (or a helper input)
// through the code above and checks the resulting structure.
#[cfg(test)]
mod tests {
    use super::*;

    // A minimal graph with one input, one constant, one node and one output parses, and
    // the top-level counts, format string and version are as expected.
    #[test]
    fn test_parse_simple_graph() {
        let input = r#"
webnn_graph "test" v1 {
 inputs {
 x: f32[1, 10];
 }
 consts {
 W: f32[10, 5] @weights("W");
 }
 nodes {
 result = matmul(x, W);
 }
 outputs { result; }
}
"#;
        let graph = parse_wg_text(input).unwrap();
        assert_eq!(graph.format, "webnn-graph-json");
        assert_eq!(graph.version, 1);
        assert_eq!(graph.inputs.len(), 1);
        assert_eq!(graph.consts.len(), 1);
        assert_eq!(graph.nodes.len(), 1);
        assert_eq!(graph.outputs.len(), 1);
    }

    // Input declarations keep their data type and static shape.
    #[test]
    fn test_parse_inputs() {
        let input = r#"
webnn_graph "test" v1 {
 inputs {
 x: f32[1, 10];
 y: i32[5];
 }
 nodes {}
 outputs { x; }
}
"#;
        let graph = parse_wg_text(input).unwrap();
        assert_eq!(graph.inputs.len(), 2);
        assert!(graph.inputs.contains_key("x"));
        assert!(graph.inputs.contains_key("y"));

        let x_desc = &graph.inputs["x"];
        assert_eq!(x_desc.data_type, DataType::Float32);
        assert_eq!(
            x_desc.shape,
            vec![Dimension::Static(1), Dimension::Static(10)]
        );

        let y_desc = &graph.inputs["y"];
        assert_eq!(y_desc.data_type, DataType::Int32);
        assert_eq!(y_desc.shape, vec![Dimension::Static(5)]);
    }

    // A `dyn("name", max)` dimension becomes `Dimension::Dynamic` with that name and
    // maximum, and may be mixed with static dimensions.
    #[test]
    fn test_parse_dynamic_input_shape() {
        let input = r#"
webnn_graph "test" v2 {
 inputs {
 x: f32[dyn("batch_size", 8), 128];
 }
 nodes {}
 outputs { x; }
}
"#;
        let graph = parse_wg_text(input).unwrap();
        let x_desc = &graph.inputs["x"];
        assert!(matches!(
            &x_desc.shape[0],
            Dimension::Dynamic(d) if d.name == "batch_size" && d.max_size == 8
        ));
        assert!(matches!(&x_desc.shape[1], Dimension::Static(128)));
    }

    // `@weights("ref")` stores the quoted reference, which may differ from the constant's
    // own name.
    #[test]
    fn test_parse_consts_with_weights() {
        let input = r#"
webnn_graph "test" v1 {
 inputs { x: f32[1]; }
 consts {
 W: f32[10, 5] @weights("W");
 b: f32[5] @weights("bias");
 }
 nodes {}
 outputs { x; }
}
"#;
        let graph = parse_wg_text(input).unwrap();
        assert_eq!(graph.consts.len(), 2);

        let w = &graph.consts["W"];
        assert_eq!(w.data_type, DataType::Float32);
        assert_eq!(w.shape, vec![10, 5]);
        assert!(matches!(&w.init, ConstInit::Weights { r#ref } if r#ref == "W"));

        let b = &graph.consts["b"];
        assert!(matches!(&b.init, ConstInit::Weights { r#ref } if r#ref == "bias"));
    }

    // `@scalar(n)` stores the numeric value as the constant's initializer.
    #[test]
    fn test_parse_consts_with_scalar() {
        let input = r#"
webnn_graph "test" v1 {
 inputs { x: f32[1]; }
 consts {
 scale: f32[1] @scalar(2.5);
 }
 nodes {}
 outputs { x; }
}
"#;
        let graph = parse_wg_text(input).unwrap();
        let scale = &graph.consts["scale"];
        match &scale.init {
            ConstInit::Scalar { value } => {
                assert_eq!(value.as_f64().unwrap(), 2.5);
            }
            _ => panic!("Expected scalar init"),
        }
    }

    // A plain assignment yields a node whose id is the target and whose inputs are the
    // positional arguments, with no options.
    #[test]
    fn test_parse_nodes() {
        let input = r#"
webnn_graph "test" v1 {
 inputs { x: f32[1, 2048]; }
 consts { W: f32[2048, 1000] @weights("W"); }
 nodes {
 result = matmul(x, W);
 }
 outputs { result; }
}
"#;
        let graph = parse_wg_text(input).unwrap();
        assert_eq!(graph.nodes.len(), 1);

        let node = &graph.nodes[0];
        assert_eq!(node.id, "result");
        assert_eq!(node.op, "matmul");
        assert_eq!(node.inputs, vec!["x", "W"]);
        assert!(node.options.is_empty());
    }

    // `key=value` arguments land in the options map rather than the input list.
    #[test]
    fn test_parse_nodes_with_options() {
        let input = r#"
webnn_graph "test" v1 {
 inputs { x: f32[1, 10]; }
 nodes {
 result = softmax(x, axis=1);
 }
 outputs { result; }
}
"#;
        let graph = parse_wg_text(input).unwrap();
        let node = &graph.nodes[0];
        assert_eq!(node.op, "softmax");
        assert_eq!(node.inputs, vec!["x"]);
        assert_eq!(node.options.get("axis").unwrap().as_i64().unwrap(), 1);
    }

    // `[a, b] = op(...)` records every target as an output and uses the first as the id.
    #[test]
    fn test_parse_multi_assign() {
        let input = r#"
webnn_graph "test" v1 {
 inputs { x: f32[10]; }
 nodes {
 [a, b] = split(x);
 }
 outputs { a; }
}
"#;
        let graph = parse_wg_text(input).unwrap();
        let node = &graph.nodes[0];
        assert_eq!(node.id, "a");
        assert_eq!(node.op, "split");
        assert_eq!(node.outputs, Some(vec!["a".to_string(), "b".to_string()]));
    }

    // A bracketed list passed positionally is flattened into the node's inputs, while a
    // following named argument still goes to the options.
    #[test]
    fn test_parse_concat_bracket_input_list() {
        let input = r#"
webnn_graph "model" v1 {
 inputs {
 tensors_0: f32[2, 3, 4, 5];
 tensors_1: f32[2, 3, 4, 5];
 }
 nodes {
 [operand_1] = concat([tensors_0, tensors_1], axis=0);
 }
 outputs { operand_1; }
}
"#;
        let graph = parse_wg_text(input).unwrap();
        let node = &graph.nodes[0];
        assert_eq!(node.op, "concat");
        assert_eq!(node.inputs, vec!["tensors_0", "tensors_1"]);
        assert_eq!(node.options.get("axis").unwrap().as_i64().unwrap(), 0);
    }

    // Each output name maps to itself.
    #[test]
    fn test_parse_outputs() {
        let input = r#"
webnn_graph "test" v1 {
 inputs { x: f32[1]; }
 nodes {
 a = relu(x);
 b = sigmoid(x);
 }
 outputs { a; b; }
}
"#;
        let graph = parse_wg_text(input).unwrap();
        assert_eq!(graph.outputs.len(), 2);
        assert_eq!(graph.outputs.get("a").unwrap(), "a");
        assert_eq!(graph.outputs.get("b").unwrap(), "b");
    }

    // An unknown dtype spelling (`float32`) must fail, and the test expects the failure
    // to surface as `ParseError::Pest` rather than any other variant.
    #[test]
    fn test_parse_invalid_dtype() {
        let input = r#"
webnn_graph "test" v1 {
 inputs { x: float32[1]; }
 nodes {}
 outputs { x; }
}
"#;
        let result = parse_wg_text(input);
        assert!(result.is_err());
        match result {
            Err(ParseError::Pest(_)) => {}
            Err(e) => panic!("Expected Pest parse error, got: {:?}", e),
            Ok(_) => panic!("Expected error but parsing succeeded"),
        }
    }

    // `unquote` strips surrounding quotes, resolves `\"` and `\\`, and leaves unquoted
    // text untouched.
    #[test]
    fn test_unquote() {
        assert_eq!(unquote(r#""hello""#), "hello");
        assert_eq!(unquote(r#""hello\"world""#), "hello\"world");
        assert_eq!(unquote(r#""path\\to\\file""#), "path\\to\\file");
        assert_eq!(unquote("no_quotes"), "no_quotes");
    }

    // Integers, decimals and scientific notation are all accepted as numbers.
    #[test]
    fn test_parse_number_value() {
        let int_val = parse_number_value("42");
        assert_eq!(int_val.as_i64().unwrap(), 42);

        let float_val = parse_number_value("3.12");
        assert_eq!(float_val.as_f64().unwrap(), 3.12);

        let sci_val = parse_number_value("1e-3");
        assert_eq!(sci_val.as_f64().unwrap(), 0.001);
    }

    // Identifiers may start with `$`, both as declaration names and as call arguments.
    #[test]
    fn test_parse_dollar_sign_identifiers() {
        let input = r#"
webnn_graph "test" v1 {
 inputs {
 x: f32[1, 10];
 }
 consts {
 $_weight: f32[10, 5] @weights("W");
 }
 nodes {
 $_temp = relu(x);
 result = matmul($_temp, $_weight);
 }
 outputs { result; }
}
"#;
        let graph = parse_wg_text(input).unwrap();
        assert_eq!(graph.inputs.len(), 1);
        assert_eq!(graph.consts.len(), 1);
        assert!(graph.consts.contains_key("$_weight"));
        assert_eq!(graph.nodes.len(), 2);
        assert_eq!(graph.nodes[0].id, "$_temp");
        assert_eq!(graph.nodes[1].id, "result");
        assert_eq!(graph.nodes[1].inputs, vec!["$_temp", "$_weight"]);
    }
}