1use crate::ast::{new_graph_json, ConstDecl, ConstInit, DataType, GraphJson, Node, OperandDesc};
2use pest::iterators::Pair;
3use pest::Parser;
4use pest_derive::Parser;
5use serde_json::{Map, Value};
6use std::collections::BTreeMap;
7use thiserror::Error;
8
/// Parser type generated by `pest_derive` from the grammar file `wg.pest`.
///
/// The derive also generates the `Rule` enum that the rest of this file
/// matches on.
#[derive(Parser)]
#[grammar = "wg.pest"]
struct WGParser;
12
/// Errors produced while turning WG text into a graph.
#[derive(Debug, Error)]
pub enum ParseError {
    /// The input did not match the grammar; wraps the (boxed) pest error.
    #[error("parse error: {0}")]
    Pest(Box<pest::error::Error<Rule>>),
    /// A data-type name was not recognised by `DataType::from_wg`; carries the offending text.
    #[error("invalid dtype: {0}")]
    BadDType(String),
    /// The parse tree had an unexpected shape (missing pair, unexpected rule, unparsable int).
    #[error("internal error: {0}")]
    Internal(String),
}
22
23impl From<pest::error::Error<Rule>> for ParseError {
24 fn from(err: pest::error::Error<Rule>) -> Self {
25 ParseError::Pest(Box::new(err))
26 }
27}
28
/// Result of parsing one expression, as the tuple `(op, inputs, options, outputs)`.
type ParsedExpr = (String, Vec<String>, Map<String, Value>, Option<Vec<String>>);
30
31pub fn parse_wg_text(input: &str) -> Result<GraphJson, ParseError> {
32 let mut pairs = WGParser::parse(Rule::file, input)?;
33 let file = pairs
34 .next()
35 .ok_or_else(|| ParseError::Internal("missing file".into()))?;
36
37 let mut g = new_graph_json();
38 let mut nodes: Vec<Node> = Vec::new();
39
40 for p in file.into_inner() {
41 match p.as_rule() {
42 Rule::header => {
43 for inner in p.into_inner() {
45 if inner.as_rule() == Rule::string {
46 g.name = Some(unquote(inner.as_str()));
47 break;
48 }
49 }
50 }
51 Rule::inputs_block => parse_inputs_block(p, &mut g.inputs)?,
52 Rule::consts_block => parse_consts_block(p, &mut g.consts)?,
53 Rule::nodes_block => parse_nodes_block(p, &mut nodes)?,
54 Rule::outputs_block => parse_outputs_block(p, &mut g.outputs)?,
55 _ => {}
56 }
57 }
58
59 g.nodes = nodes;
60 Ok(g)
61}
62
63fn parse_inputs_block(
64 p: Pair<Rule>,
65 out: &mut BTreeMap<String, OperandDesc>,
66) -> Result<(), ParseError> {
67 for inner in p.into_inner() {
68 if inner.as_rule() == Rule::input_decl {
69 let mut it = inner.into_inner();
70 let name = it.next().unwrap().as_str().to_string();
71 let (dt, shape) = parse_ty(it.next().unwrap())?;
72 out.insert(
73 name,
74 OperandDesc {
75 data_type: dt,
76 shape,
77 },
78 );
79 }
80 }
81 Ok(())
82}
83
84fn parse_consts_block(
85 p: Pair<Rule>,
86 out: &mut BTreeMap<String, ConstDecl>,
87) -> Result<(), ParseError> {
88 for inner in p.into_inner() {
89 if inner.as_rule() == Rule::const_decl {
90 let mut it = inner.into_inner();
91 let name = it.next().unwrap().as_str().to_string();
92 let (dt, shape) = parse_ty(it.next().unwrap())?;
93
94 let mut init: Option<ConstInit> = None;
95 for ann in it {
96 if ann.as_rule() == Rule::const_annot {
97 let text = ann.as_str();
98 if text.starts_with("@weights") {
99 let s = ann
100 .into_inner()
101 .find(|p| p.as_rule() == Rule::string)
102 .map(|p| unquote(p.as_str()))
103 .unwrap_or_else(|| name.clone());
104 init = Some(ConstInit::Weights { r#ref: s });
105 } else if text.starts_with("@scalar") {
106 let n = ann
107 .into_inner()
108 .find(|p| p.as_rule() == Rule::number)
109 .map(|p| parse_number_value(p.as_str()))
110 .unwrap_or(Value::Null);
111 init = Some(ConstInit::Scalar { value: n });
112 }
113 }
114 }
115
116 let init = init.unwrap_or(ConstInit::Weights {
117 r#ref: name.clone(),
118 });
119 out.insert(
120 name,
121 ConstDecl {
122 data_type: dt,
123 shape,
124 init,
125 },
126 );
127 }
128 }
129 Ok(())
130}
131
132fn parse_nodes_block(p: Pair<Rule>, out: &mut Vec<Node>) -> Result<(), ParseError> {
133 for inner in p.into_inner() {
134 if inner.as_rule() != Rule::stmt {
135 continue;
136 }
137 let stmt = inner.into_inner().next().unwrap();
138 match stmt.as_rule() {
139 Rule::assign => out.push(parse_assign(stmt)?),
140 Rule::multi_assign => out.push(parse_multi_assign(stmt)?),
141 _ => {}
142 }
143 }
144 Ok(())
145}
146
147fn parse_assign(p: Pair<Rule>) -> Result<Node, ParseError> {
148 let mut it = p.into_inner();
149 let id = it.next().unwrap().as_str().to_string();
150 let (op, inputs, options, outputs) = parse_expr(it.next().unwrap())?;
151 Ok(Node {
152 id,
153 op,
154 inputs,
155 options,
156 outputs,
157 })
158}
159
160fn parse_multi_assign(p: Pair<Rule>) -> Result<Node, ParseError> {
161 let mut it = p.into_inner();
162 let mut outs: Vec<String> = Vec::new();
163
164 while let Some(next) = it.peek() {
168 if next.as_rule() == Rule::expr {
169 break;
170 }
171 let t = it.next().unwrap();
172 if t.as_rule() == Rule::ident {
173 outs.push(t.as_str().to_string());
174 }
175 }
176
177 let expr = it
178 .next()
179 .ok_or_else(|| ParseError::Internal("missing expr in multi_assign".into()))?;
180 let (op, inputs, options, _outputs_unused) = parse_expr(expr)?;
181 let id = outs.first().cloned().unwrap_or_else(|| "tmp".into());
183 Ok(Node {
184 id,
185 op,
186 inputs,
187 options,
188 outputs: Some(outs),
189 })
190}
191
192fn parse_expr(p: Pair<Rule>) -> Result<ParsedExpr, ParseError> {
193 match p.as_rule() {
194 Rule::expr => parse_expr(p.into_inner().next().unwrap()),
195 Rule::call => parse_call(p),
196 Rule::ident => Ok((
197 String::new(),
198 vec![p.as_str().to_string()],
199 Map::new(),
200 None,
201 )),
202 _ => Err(ParseError::Internal(format!(
203 "unexpected expr rule: {:?}",
204 p.as_rule()
205 ))),
206 }
207}
208
209fn parse_call(p: Pair<Rule>) -> Result<ParsedExpr, ParseError> {
210 let mut it = p.into_inner();
211 let op = it.next().unwrap().as_str().to_string();
212 let mut inputs: Vec<String> = Vec::new();
213 let mut options: Map<String, Value> = Map::new();
214
215 if let Some(args) = it.next() {
216 if args.as_rule() == Rule::args {
217 for arg in args.into_inner() {
218 if arg.as_rule() != Rule::arg {
219 continue;
220 }
221 let mut a = arg.into_inner().peekable();
222
223 let first = match a.next() {
225 Some(f) => f,
226 None => continue,
227 };
228
229 if first.as_rule() == Rule::ident
230 && a.peek().is_some()
231 && a.peek().unwrap().as_rule() == Rule::value
232 {
233 let key = first.as_str().to_string();
235 let val = parse_value(a.next().unwrap())?;
236 options.insert(key, val);
237 } else {
238 match first.as_rule() {
240 Rule::value => {
241 let v = parse_value(first)?;
242 if let Value::String(s) = v {
243 inputs.push(s);
244 } else if let Some(sym) = v.as_str() {
245 inputs.push(sym.to_string());
246 }
247 }
248 Rule::ident => inputs.push(first.as_str().to_string()),
249 _ => {}
250 }
251 }
252 }
253 }
254 }
255
256 Ok((op, inputs, options, None))
257}
258
259fn parse_outputs_block(
260 p: Pair<Rule>,
261 out: &mut BTreeMap<String, String>,
262) -> Result<(), ParseError> {
263 for inner in p.into_inner() {
266 if inner.as_rule() == Rule::output_item {
267 for item in inner.into_inner() {
268 if item.as_rule() == Rule::ident {
269 let name = item.as_str().to_string();
270 out.insert(name.clone(), name);
271 }
272 }
273 }
274 }
275 Ok(())
276}
277
278fn parse_ty(p: Pair<Rule>) -> Result<(DataType, Vec<u32>), ParseError> {
279 let mut it = p.into_inner();
280 let dt_s = it.next().unwrap().as_str();
281 let dt = DataType::from_wg(dt_s).ok_or_else(|| ParseError::BadDType(dt_s.to_string()))?;
282 let shape = parse_shape(it.next().unwrap())?;
283 Ok((dt, shape))
284}
285
286fn parse_shape(p: Pair<Rule>) -> Result<Vec<u32>, ParseError> {
287 let mut shape = Vec::new();
288 for inner in p.into_inner() {
289 if inner.as_rule() == Rule::int {
290 let v: u32 = inner
291 .as_str()
292 .parse()
293 .map_err(|_| ParseError::Internal("bad int".into()))?;
294 shape.push(v);
295 }
296 }
297 Ok(shape)
298}
299
300fn parse_value(p: Pair<Rule>) -> Result<Value, ParseError> {
301 match p.as_rule() {
302 Rule::value => parse_value(p.into_inner().next().unwrap()),
303 Rule::literal => parse_value(p.into_inner().next().unwrap()),
304 Rule::string => Ok(Value::String(unquote(p.as_str()))),
305 Rule::number => Ok(parse_number_value(p.as_str())),
306 Rule::boolean => Ok(Value::Bool(p.as_str() == "true")),
307 Rule::null => Ok(Value::Null),
308 Rule::array => {
309 let mut arr = Vec::new();
310 for inner in p.into_inner() {
311 if inner.as_rule() == Rule::value {
312 arr.push(parse_value(inner)?);
313 }
314 }
315 Ok(Value::Array(arr))
316 }
317 Rule::ident => Ok(Value::String(p.as_str().to_string())),
318 _ => Err(ParseError::Internal(format!(
319 "unexpected value rule: {:?}",
320 p.as_rule()
321 ))),
322 }
323}
324
325fn parse_number_value(s: &str) -> Value {
326 if !s.contains('.') && !s.contains('e') && !s.contains('E') {
328 if let Ok(i) = s.parse::<i64>() {
329 return Value::Number(i.into());
330 }
331 }
332 Value::Number(serde_json::Number::from_f64(s.parse::<f64>().unwrap_or(0.0)).unwrap())
333}
334
/// Strips one pair of surrounding double quotes, if both are present, and
/// then unescapes `\"` to `"` followed by `\\` to `\`.
///
/// Text that is not wrapped in quotes keeps its outer characters but still
/// has the two escape replacements applied.
fn unquote(s: &str) -> String {
    // A lone `"` has a prefix quote but no remaining suffix quote, so it is left intact.
    let inner = s
        .strip_prefix('"')
        .and_then(|rest| rest.strip_suffix('"'))
        .unwrap_or(s);

    let quotes_restored = inner.replace("\\\"", "\"");
    quotes_restored.replace("\\\\", "\\")
}
343
// Unit tests for the WG text parser: end-to-end parses of small graphs plus
// direct checks of the `unquote` and `parse_number_value` helpers.
#[cfg(test)]
mod tests {
    use super::*;

    // A graph using all four blocks parses, and each block contributes one entry.
    #[test]
    fn test_parse_simple_graph() {
        let input = r#"
webnn_graph "test" v1 {
 inputs {
 x: f32[1, 10];
 }
 consts {
 W: f32[10, 5] @weights("W");
 }
 nodes {
 result = matmul(x, W);
 }
 outputs { result; }
}
"#;
        let graph = parse_wg_text(input).unwrap();
        assert_eq!(graph.format, "webnn-graph-json");
        assert_eq!(graph.version, 1);
        assert_eq!(graph.inputs.len(), 1);
        assert_eq!(graph.consts.len(), 1);
        assert_eq!(graph.nodes.len(), 1);
        assert_eq!(graph.outputs.len(), 1);
    }

    // Input declarations keep their name, data type and shape.
    #[test]
    fn test_parse_inputs() {
        let input = r#"
webnn_graph "test" v1 {
 inputs {
 x: f32[1, 10];
 y: i32[5];
 }
 nodes {}
 outputs { x; }
}
"#;
        let graph = parse_wg_text(input).unwrap();
        assert_eq!(graph.inputs.len(), 2);
        assert!(graph.inputs.contains_key("x"));
        assert!(graph.inputs.contains_key("y"));

        let x_desc = &graph.inputs["x"];
        assert_eq!(x_desc.data_type, DataType::Float32);
        assert_eq!(x_desc.shape, vec![1, 10]);

        let y_desc = &graph.inputs["y"];
        assert_eq!(y_desc.data_type, DataType::Int32);
        assert_eq!(y_desc.shape, vec![5]);
    }

    // `@weights("...")` stores the quoted string as the weights reference,
    // which may differ from the constant's own name.
    #[test]
    fn test_parse_consts_with_weights() {
        let input = r#"
webnn_graph "test" v1 {
 inputs { x: f32[1]; }
 consts {
 W: f32[10, 5] @weights("W");
 b: f32[5] @weights("bias");
 }
 nodes {}
 outputs { x; }
}
"#;
        let graph = parse_wg_text(input).unwrap();
        assert_eq!(graph.consts.len(), 2);

        let w = &graph.consts["W"];
        assert_eq!(w.data_type, DataType::Float32);
        assert_eq!(w.shape, vec![10, 5]);
        assert!(matches!(&w.init, ConstInit::Weights { r#ref } if r#ref == "W"));

        let b = &graph.consts["b"];
        assert!(matches!(&b.init, ConstInit::Weights { r#ref } if r#ref == "bias"));
    }

    // `@scalar(...)` stores the parsed number as the constant's initial value.
    #[test]
    fn test_parse_consts_with_scalar() {
        let input = r#"
webnn_graph "test" v1 {
 inputs { x: f32[1]; }
 consts {
 scale: f32[1] @scalar(2.5);
 }
 nodes {}
 outputs { x; }
}
"#;
        let graph = parse_wg_text(input).unwrap();
        let scale = &graph.consts["scale"];
        match &scale.init {
            ConstInit::Scalar { value } => {
                assert_eq!(value.as_f64().unwrap(), 2.5);
            }
            _ => panic!("Expected scalar init"),
        }
    }

    // A single assignment becomes one node with positional inputs and no options.
    #[test]
    fn test_parse_nodes() {
        let input = r#"
webnn_graph "test" v1 {
 inputs { x: f32[1, 2048]; }
 consts { W: f32[2048, 1000] @weights("W"); }
 nodes {
 result = matmul(x, W);
 }
 outputs { result; }
}
"#;
        let graph = parse_wg_text(input).unwrap();
        assert_eq!(graph.nodes.len(), 1);

        let node = &graph.nodes[0];
        assert_eq!(node.id, "result");
        assert_eq!(node.op, "matmul");
        assert_eq!(node.inputs, vec!["x", "W"]);
        assert!(node.options.is_empty());
    }

    // `key=value` arguments land in `options`, not in `inputs`.
    #[test]
    fn test_parse_nodes_with_options() {
        let input = r#"
webnn_graph "test" v1 {
 inputs { x: f32[1, 10]; }
 nodes {
 result = softmax(x, axis=1);
 }
 outputs { result; }
}
"#;
        let graph = parse_wg_text(input).unwrap();
        let node = &graph.nodes[0];
        assert_eq!(node.op, "softmax");
        assert_eq!(node.inputs, vec!["x"]);
        assert_eq!(node.options.get("axis").unwrap().as_i64().unwrap(), 1);
    }

    // A multi-target assignment uses the first target as the node id and
    // lists every target in `outputs`.
    #[test]
    fn test_parse_multi_assign() {
        let input = r#"
webnn_graph "test" v1 {
 inputs { x: f32[10]; }
 nodes {
 [a, b] = split(x);
 }
 outputs { a; }
}
"#;
        let graph = parse_wg_text(input).unwrap();
        let node = &graph.nodes[0];
        assert_eq!(node.id, "a");
        assert_eq!(node.op, "split");
        assert_eq!(node.outputs, Some(vec!["a".to_string(), "b".to_string()]));
    }

    // Each output name is mapped to itself.
    #[test]
    fn test_parse_outputs() {
        let input = r#"
webnn_graph "test" v1 {
 inputs { x: f32[1]; }
 nodes {
 a = relu(x);
 b = sigmoid(x);
 }
 outputs { a; b; }
}
"#;
        let graph = parse_wg_text(input).unwrap();
        assert_eq!(graph.outputs.len(), 2);
        assert_eq!(graph.outputs.get("a").unwrap(), "a");
        assert_eq!(graph.outputs.get("b").unwrap(), "b");
    }

    // An unknown dtype spelling (`float32`) must fail to parse.
    #[test]
    fn test_parse_invalid_dtype() {
        let input = r#"
webnn_graph "test" v1 {
 inputs { x: float32[1]; }
 nodes {}
 outputs { x; }
}
"#;
        let result = parse_wg_text(input);
        assert!(result.is_err());
        match result {
            // The failure is expected as a Pest error, not `BadDType`.
            // NOTE(review): presumably the grammar only accepts known dtype
            // spellings, so it rejects this first; confirm against wg.pest.
            Err(ParseError::Pest(_)) => {}
            Err(e) => panic!("Expected Pest parse error, got: {:?}", e),
            Ok(_) => panic!("Expected error but parsing succeeded"),
        }
    }

    // `unquote` strips surrounding quotes, unescapes `\"` and `\\`, and leaves
    // unquoted text unchanged.
    #[test]
    fn test_unquote() {
        assert_eq!(unquote(r#""hello""#), "hello");
        assert_eq!(unquote(r#""hello\"world""#), "hello\"world");
        assert_eq!(unquote(r#""path\\to\\file""#), "path\\to\\file");
        assert_eq!(unquote("no_quotes"), "no_quotes");
    }

    // Plain integers become JSON integers; decimal and exponent forms become floats.
    #[test]
    fn test_parse_number_value() {
        let int_val = parse_number_value("42");
        assert_eq!(int_val.as_i64().unwrap(), 42);

        let float_val = parse_number_value("3.12");
        assert_eq!(float_val.as_f64().unwrap(), 3.12);

        let sci_val = parse_number_value("1e-3");
        assert_eq!(sci_val.as_f64().unwrap(), 0.001);
    }
}