1use crate::ast::{new_graph_json, ConstDecl, ConstInit, DataType, GraphJson, Node, OperandDesc};
2use pest::iterators::Pair;
3use pest::Parser;
4use pest_derive::Parser;
5use serde_json::{Map, Value};
6use std::collections::BTreeMap;
7use thiserror::Error;
8
/// Pest-derived parser for the `wg` graph text format.
///
/// The grammar (and therefore the `Rule` enum used throughout this module) is generated
/// from `wg.pest`.
#[derive(Parser)]
#[grammar = "wg.pest"]
struct WGParser;
12
/// Errors produced while turning `wg` text into a [`GraphJson`].
#[derive(Debug, Error)]
pub enum ParseError {
    /// Syntax error reported by the pest grammar.
    ///
    /// Boxed, presumably so that `Result<_, ParseError>` stays small
    /// (`pest::error::Error` carries position and message data).
    #[error("parse error: {0}")]
    Pest(Box<pest::error::Error<Rule>>),
    /// A type annotation named a data type that `DataType::from_wg` did not recognise.
    #[error("invalid dtype: {0}")]
    BadDType(String),
    /// The parse tree had an unexpected shape, or a token could not be converted.
    /// One example of the latter is a shape dimension that does not fit in `u32`.
    #[error("internal error: {0}")]
    Internal(String),
}
22
23impl From<pest::error::Error<Rule>> for ParseError {
24 fn from(err: pest::error::Error<Rule>) -> Self {
25 ParseError::Pest(Box::new(err))
26 }
27}
28
/// Components of a parsed right-hand-side expression.
///
/// The tuple is `(op name, positional inputs, keyword options, explicit output names)`.
/// - `op` is empty when the expression is a bare identifier.
/// - The output list is currently always `None` from `parse_expr`/`parse_call`.
type ParsedExpr = (String, Vec<String>, Map<String, Value>, Option<Vec<String>>);
30
31pub fn parse_wg_text(input: &str) -> Result<GraphJson, ParseError> {
32 let mut pairs = WGParser::parse(Rule::file, input)?;
33 let file = pairs
34 .next()
35 .ok_or_else(|| ParseError::Internal("missing file".into()))?;
36
37 let mut g = new_graph_json();
38 let mut nodes: Vec<Node> = Vec::new();
39
40 for p in file.into_inner() {
41 match p.as_rule() {
42 Rule::inputs_block => parse_inputs_block(p, &mut g.inputs)?,
43 Rule::consts_block => parse_consts_block(p, &mut g.consts)?,
44 Rule::nodes_block => parse_nodes_block(p, &mut nodes)?,
45 Rule::outputs_block => parse_outputs_block(p, &mut g.outputs)?,
46 _ => {}
47 }
48 }
49
50 g.nodes = nodes;
51 Ok(g)
52}
53
54fn parse_inputs_block(
55 p: Pair<Rule>,
56 out: &mut BTreeMap<String, OperandDesc>,
57) -> Result<(), ParseError> {
58 for inner in p.into_inner() {
59 if inner.as_rule() == Rule::input_decl {
60 let mut it = inner.into_inner();
61 let name = it.next().unwrap().as_str().to_string();
62 let (dt, shape) = parse_ty(it.next().unwrap())?;
63 out.insert(
64 name,
65 OperandDesc {
66 data_type: dt,
67 shape,
68 },
69 );
70 }
71 }
72 Ok(())
73}
74
75fn parse_consts_block(
76 p: Pair<Rule>,
77 out: &mut BTreeMap<String, ConstDecl>,
78) -> Result<(), ParseError> {
79 for inner in p.into_inner() {
80 if inner.as_rule() == Rule::const_decl {
81 let mut it = inner.into_inner();
82 let name = it.next().unwrap().as_str().to_string();
83 let (dt, shape) = parse_ty(it.next().unwrap())?;
84
85 let mut init: Option<ConstInit> = None;
86 for ann in it {
87 if ann.as_rule() == Rule::const_annot {
88 let text = ann.as_str();
89 if text.starts_with("@weights") {
90 let s = ann
91 .into_inner()
92 .find(|p| p.as_rule() == Rule::string)
93 .map(|p| unquote(p.as_str()))
94 .unwrap_or_else(|| name.clone());
95 init = Some(ConstInit::Weights { r#ref: s });
96 } else if text.starts_with("@scalar") {
97 let n = ann
98 .into_inner()
99 .find(|p| p.as_rule() == Rule::number)
100 .map(|p| parse_number_value(p.as_str()))
101 .unwrap_or(Value::Null);
102 init = Some(ConstInit::Scalar { value: n });
103 }
104 }
105 }
106
107 let init = init.unwrap_or(ConstInit::Weights {
108 r#ref: name.clone(),
109 });
110 out.insert(
111 name,
112 ConstDecl {
113 data_type: dt,
114 shape,
115 init,
116 },
117 );
118 }
119 }
120 Ok(())
121}
122
123fn parse_nodes_block(p: Pair<Rule>, out: &mut Vec<Node>) -> Result<(), ParseError> {
124 for inner in p.into_inner() {
125 if inner.as_rule() != Rule::stmt {
126 continue;
127 }
128 let stmt = inner.into_inner().next().unwrap();
129 match stmt.as_rule() {
130 Rule::assign => out.push(parse_assign(stmt)?),
131 Rule::multi_assign => out.push(parse_multi_assign(stmt)?),
132 _ => {}
133 }
134 }
135 Ok(())
136}
137
138fn parse_assign(p: Pair<Rule>) -> Result<Node, ParseError> {
139 let mut it = p.into_inner();
140 let id = it.next().unwrap().as_str().to_string();
141 let (op, inputs, options, outputs) = parse_expr(it.next().unwrap())?;
142 Ok(Node {
143 id,
144 op,
145 inputs,
146 options,
147 outputs,
148 })
149}
150
151fn parse_multi_assign(p: Pair<Rule>) -> Result<Node, ParseError> {
152 let mut it = p.into_inner();
153 let mut outs: Vec<String> = Vec::new();
154
155 while let Some(next) = it.peek() {
159 if next.as_rule() == Rule::expr {
160 break;
161 }
162 let t = it.next().unwrap();
163 if t.as_rule() == Rule::ident {
164 outs.push(t.as_str().to_string());
165 }
166 }
167
168 let expr = it
169 .next()
170 .ok_or_else(|| ParseError::Internal("missing expr in multi_assign".into()))?;
171 let (op, inputs, options, _outputs_unused) = parse_expr(expr)?;
172 let id = outs.first().cloned().unwrap_or_else(|| "tmp".into());
174 Ok(Node {
175 id,
176 op,
177 inputs,
178 options,
179 outputs: Some(outs),
180 })
181}
182
183fn parse_expr(p: Pair<Rule>) -> Result<ParsedExpr, ParseError> {
184 match p.as_rule() {
185 Rule::expr => parse_expr(p.into_inner().next().unwrap()),
186 Rule::call => parse_call(p),
187 Rule::ident => Ok((
188 String::new(),
189 vec![p.as_str().to_string()],
190 Map::new(),
191 None,
192 )),
193 _ => Err(ParseError::Internal(format!(
194 "unexpected expr rule: {:?}",
195 p.as_rule()
196 ))),
197 }
198}
199
200fn parse_call(p: Pair<Rule>) -> Result<ParsedExpr, ParseError> {
201 let mut it = p.into_inner();
202 let op = it.next().unwrap().as_str().to_string();
203 let mut inputs: Vec<String> = Vec::new();
204 let mut options: Map<String, Value> = Map::new();
205
206 if let Some(args) = it.next() {
207 if args.as_rule() == Rule::args {
208 for arg in args.into_inner() {
209 if arg.as_rule() != Rule::arg {
210 continue;
211 }
212 let mut a = arg.into_inner().peekable();
213
214 let first = match a.next() {
216 Some(f) => f,
217 None => continue,
218 };
219
220 if first.as_rule() == Rule::ident
221 && a.peek().is_some()
222 && a.peek().unwrap().as_rule() == Rule::value
223 {
224 let key = first.as_str().to_string();
226 let val = parse_value(a.next().unwrap())?;
227 options.insert(key, val);
228 } else {
229 match first.as_rule() {
231 Rule::value => {
232 let v = parse_value(first)?;
233 if let Value::String(s) = v {
234 inputs.push(s);
235 } else if let Some(sym) = v.as_str() {
236 inputs.push(sym.to_string());
237 }
238 }
239 Rule::ident => inputs.push(first.as_str().to_string()),
240 _ => {}
241 }
242 }
243 }
244 }
245 }
246
247 Ok((op, inputs, options, None))
248}
249
250fn parse_outputs_block(
251 p: Pair<Rule>,
252 out: &mut BTreeMap<String, String>,
253) -> Result<(), ParseError> {
254 for inner in p.into_inner() {
257 if inner.as_rule() == Rule::output_item {
258 for item in inner.into_inner() {
259 if item.as_rule() == Rule::ident {
260 let name = item.as_str().to_string();
261 out.insert(name.clone(), name);
262 }
263 }
264 }
265 }
266 Ok(())
267}
268
269fn parse_ty(p: Pair<Rule>) -> Result<(DataType, Vec<u32>), ParseError> {
270 let mut it = p.into_inner();
271 let dt_s = it.next().unwrap().as_str();
272 let dt = DataType::from_wg(dt_s).ok_or_else(|| ParseError::BadDType(dt_s.to_string()))?;
273 let shape = parse_shape(it.next().unwrap())?;
274 Ok((dt, shape))
275}
276
277fn parse_shape(p: Pair<Rule>) -> Result<Vec<u32>, ParseError> {
278 let mut shape = Vec::new();
279 for inner in p.into_inner() {
280 if inner.as_rule() == Rule::int {
281 let v: u32 = inner
282 .as_str()
283 .parse()
284 .map_err(|_| ParseError::Internal("bad int".into()))?;
285 shape.push(v);
286 }
287 }
288 Ok(shape)
289}
290
291fn parse_value(p: Pair<Rule>) -> Result<Value, ParseError> {
292 match p.as_rule() {
293 Rule::value => parse_value(p.into_inner().next().unwrap()),
294 Rule::literal => parse_value(p.into_inner().next().unwrap()),
295 Rule::string => Ok(Value::String(unquote(p.as_str()))),
296 Rule::number => Ok(parse_number_value(p.as_str())),
297 Rule::boolean => Ok(Value::Bool(p.as_str() == "true")),
298 Rule::null => Ok(Value::Null),
299 Rule::array => {
300 let mut arr = Vec::new();
301 for inner in p.into_inner() {
302 if inner.as_rule() == Rule::value {
303 arr.push(parse_value(inner)?);
304 }
305 }
306 Ok(Value::Array(arr))
307 }
308 Rule::ident => Ok(Value::String(p.as_str().to_string())),
309 _ => Err(ParseError::Internal(format!(
310 "unexpected value rule: {:?}",
311 p.as_rule()
312 ))),
313 }
314}
315
316fn parse_number_value(s: &str) -> Value {
317 if !s.contains('.') && !s.contains('e') && !s.contains('E') {
319 if let Ok(i) = s.parse::<i64>() {
320 return Value::Number(i.into());
321 }
322 }
323 Value::Number(serde_json::Number::from_f64(s.parse::<f64>().unwrap_or(0.0)).unwrap())
324}
325
/// Strips one pair of surrounding double quotes, if present, and decodes `\"` and `\\`.
///
/// Escape decoding runs even when `s` is not quoted. Other backslash sequences (e.g. `\n`)
/// are left untouched.
fn unquote(s: &str) -> String {
    let is_quoted = s.len() >= 2 && s.starts_with('"') && s.ends_with('"');
    // `"` is a single byte, so trimming one byte from each end stays on char boundaries.
    let body = if is_quoted { &s[1..s.len() - 1] } else { s };
    body.replace("\\\"", "\"").replace("\\\\", "\\")
}
334
// Unit tests for the `wg` text parser. They combine end-to-end parses of small graphs
// with direct checks of the `unquote` and `parse_number_value` helpers.
#[cfg(test)]
mod tests {
    use super::*;

    // Smoke test: one input, one weights const, one node and one output all land in the graph.
    #[test]
    fn test_parse_simple_graph() {
        let input = r#"
webnn_graph "test" v1 {
    inputs {
        x: f32[1, 10];
    }
    consts {
        W: f32[10, 5] @weights("W");
    }
    nodes {
        result = matmul(x, W);
    }
    outputs { result; }
}
"#;
        let graph = parse_wg_text(input).unwrap();
        // `parse_wg_text` never writes `format`/`version`; these are whatever
        // `new_graph_json()` initialises them to.
        assert_eq!(graph.format, "webnn-graph-json");
        assert_eq!(graph.version, 1);
        assert_eq!(graph.inputs.len(), 1);
        assert_eq!(graph.consts.len(), 1);
        assert_eq!(graph.nodes.len(), 1);
        assert_eq!(graph.outputs.len(), 1);
    }

    // Input declarations keep their dtype and shape, keyed by name.
    #[test]
    fn test_parse_inputs() {
        let input = r#"
webnn_graph "test" v1 {
    inputs {
        x: f32[1, 10];
        y: i32[5];
    }
    nodes {}
    outputs { x; }
}
"#;
        let graph = parse_wg_text(input).unwrap();
        assert_eq!(graph.inputs.len(), 2);
        assert!(graph.inputs.contains_key("x"));
        assert!(graph.inputs.contains_key("y"));

        let x_desc = &graph.inputs["x"];
        assert_eq!(x_desc.data_type, DataType::Float32);
        assert_eq!(x_desc.shape, vec![1, 10]);

        let y_desc = &graph.inputs["y"];
        assert_eq!(y_desc.data_type, DataType::Int32);
        assert_eq!(y_desc.shape, vec![5]);
    }

    // `@weights("key")` stores the unquoted key, which may differ from the const's name.
    #[test]
    fn test_parse_consts_with_weights() {
        let input = r#"
webnn_graph "test" v1 {
    inputs { x: f32[1]; }
    consts {
        W: f32[10, 5] @weights("W");
        b: f32[5] @weights("bias");
    }
    nodes {}
    outputs { x; }
}
"#;
        let graph = parse_wg_text(input).unwrap();
        assert_eq!(graph.consts.len(), 2);

        let w = &graph.consts["W"];
        assert_eq!(w.data_type, DataType::Float32);
        assert_eq!(w.shape, vec![10, 5]);
        assert!(matches!(&w.init, ConstInit::Weights { r#ref } if r#ref == "W"));

        let b = &graph.consts["b"];
        assert!(matches!(&b.init, ConstInit::Weights { r#ref } if r#ref == "bias"));
    }

    // `@scalar(n)` produces a `ConstInit::Scalar` holding the parsed number.
    #[test]
    fn test_parse_consts_with_scalar() {
        let input = r#"
webnn_graph "test" v1 {
    inputs { x: f32[1]; }
    consts {
        scale: f32[1] @scalar(2.5);
    }
    nodes {}
    outputs { x; }
}
"#;
        let graph = parse_wg_text(input).unwrap();
        let scale = &graph.consts["scale"];
        match &scale.init {
            ConstInit::Scalar { value } => {
                assert_eq!(value.as_f64().unwrap(), 2.5);
            }
            _ => panic!("Expected scalar init"),
        }
    }

    // A single assignment becomes a node whose id is the assigned name and whose
    // positional identifiers become inputs.
    #[test]
    fn test_parse_nodes() {
        let input = r#"
webnn_graph "test" v1 {
    inputs { x: f32[1, 2048]; }
    consts { W: f32[2048, 1000] @weights("W"); }
    nodes {
        result = matmul(x, W);
    }
    outputs { result; }
}
"#;
        let graph = parse_wg_text(input).unwrap();
        assert_eq!(graph.nodes.len(), 1);

        let node = &graph.nodes[0];
        assert_eq!(node.id, "result");
        assert_eq!(node.op, "matmul");
        assert_eq!(node.inputs, vec!["x", "W"]);
        assert!(node.options.is_empty());
    }

    // `key=value` arguments go to `options`, not `inputs`.
    #[test]
    fn test_parse_nodes_with_options() {
        let input = r#"
webnn_graph "test" v1 {
    inputs { x: f32[1, 10]; }
    nodes {
        result = softmax(x, axis=1);
    }
    outputs { result; }
}
"#;
        let graph = parse_wg_text(input).unwrap();
        let node = &graph.nodes[0];
        assert_eq!(node.op, "softmax");
        assert_eq!(node.inputs, vec!["x"]);
        assert_eq!(node.options.get("axis").unwrap().as_i64().unwrap(), 1);
    }

    // `[a, b] = ...` records both names as outputs; the node id is the first of them.
    #[test]
    fn test_parse_multi_assign() {
        let input = r#"
webnn_graph "test" v1 {
    inputs { x: f32[10]; }
    nodes {
        [a, b] = split(x);
    }
    outputs { a; }
}
"#;
        let graph = parse_wg_text(input).unwrap();
        let node = &graph.nodes[0];
        assert_eq!(node.id, "a");
        assert_eq!(node.op, "split");
        assert_eq!(node.outputs, Some(vec!["a".to_string(), "b".to_string()]));
    }

    // Each listed output maps to itself.
    #[test]
    fn test_parse_outputs() {
        let input = r#"
webnn_graph "test" v1 {
    inputs { x: f32[1]; }
    nodes {
        a = relu(x);
        b = sigmoid(x);
    }
    outputs { a; b; }
}
"#;
        let graph = parse_wg_text(input).unwrap();
        assert_eq!(graph.outputs.len(), 2);
        assert_eq!(graph.outputs.get("a").unwrap(), "a");
        assert_eq!(graph.outputs.get("b").unwrap(), "b");
    }

    // `float32` is rejected by the grammar itself, so this surfaces as a pest syntax error
    // rather than `ParseError::BadDType`.
    #[test]
    fn test_parse_invalid_dtype() {
        let input = r#"
webnn_graph "test" v1 {
    inputs { x: float32[1]; }
    nodes {}
    outputs { x; }
}
"#;
        let result = parse_wg_text(input);
        assert!(result.is_err());
        match result {
            Err(ParseError::Pest(_)) => {}
            Err(e) => panic!("Expected Pest parse error, got: {:?}", e),
            Ok(_) => panic!("Expected error but parsing succeeded"),
        }
    }

    // Quote stripping plus `\"` / `\\` decoding; unquoted text passes through.
    #[test]
    fn test_unquote() {
        assert_eq!(unquote(r#""hello""#), "hello");
        assert_eq!(unquote(r#""hello\"world""#), "hello\"world");
        assert_eq!(unquote(r#""path\\to\\file""#), "path\\to\\file");
        assert_eq!(unquote("no_quotes"), "no_quotes");
    }

    // Integer text stays an integer; decimal and exponent forms become floats.
    #[test]
    fn test_parse_number_value() {
        let int_val = parse_number_value("42");
        assert_eq!(int_val.as_i64().unwrap(), 42);

        let float_val = parse_number_value("3.12");
        assert_eq!(float_val.as_f64().unwrap(), 3.12);

        let sci_val = parse_number_value("1e-3");
        assert_eq!(sci_val.as_f64().unwrap(), 0.001);
    }
}