webnn_graph/
ast.rs

1use serde::{Deserialize, Serialize};
2use std::collections::BTreeMap;
3
4#[derive(Debug, Clone, Serialize, Deserialize)]
5pub struct GraphJson {
6    pub format: String, // "webnn-graph-json"
7    pub version: u32,   // 1
8    #[serde(skip_serializing_if = "Option::is_none")]
9    pub name: Option<String>,
10    pub inputs: BTreeMap<String, OperandDesc>,
11    #[serde(default)]
12    pub consts: BTreeMap<String, ConstDecl>,
13    pub nodes: Vec<Node>,
14    // output_name -> value reference name
15    pub outputs: BTreeMap<String, String>,
16}
17
18#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
19pub struct OperandDesc {
20    #[serde(rename = "dataType")]
21    pub data_type: DataType,
22    pub shape: Vec<u32>,
23}
24
25#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
26pub enum DataType {
27    #[serde(rename = "float32")]
28    Float32,
29    #[serde(rename = "float16")]
30    Float16,
31    #[serde(rename = "int32")]
32    Int32,
33    #[serde(rename = "uint32")]
34    Uint32,
35    #[serde(rename = "int64")]
36    Int64,
37    #[serde(rename = "uint64")]
38    Uint64,
39    #[serde(rename = "int8")]
40    Int8,
41    #[serde(rename = "uint8")]
42    Uint8,
43}
44
45impl DataType {
46    pub fn from_wg(s: &str) -> Option<Self> {
47        match s {
48            "f32" => Some(Self::Float32),
49            "f16" => Some(Self::Float16),
50            "i32" => Some(Self::Int32),
51            "u32" => Some(Self::Uint32),
52            "i64" => Some(Self::Int64),
53            "u64" => Some(Self::Uint64),
54            "i8" => Some(Self::Int8),
55            "u8" => Some(Self::Uint8),
56            _ => None,
57        }
58    }
59
60    pub fn to_wg_text(&self) -> &'static str {
61        match self {
62            Self::Float32 => "f32",
63            Self::Float16 => "f16",
64            Self::Int32 => "i32",
65            Self::Uint32 => "u32",
66            Self::Int64 => "i64",
67            Self::Uint64 => "u64",
68            Self::Int8 => "i8",
69            Self::Uint8 => "u8",
70        }
71    }
72}
73
74#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
75pub struct ConstDecl {
76    #[serde(rename = "dataType")]
77    pub data_type: DataType,
78    pub shape: Vec<u32>,
79    pub init: ConstInit,
80}
81
82#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
83#[serde(tag = "kind", rename_all = "camelCase")]
84pub enum ConstInit {
85    Weights { r#ref: String },
86    Scalar { value: serde_json::Value },
87    InlineBytes { bytes: Vec<u8> },
88}
89
90#[derive(Debug, Clone, Serialize, Deserialize)]
91pub struct Node {
92    pub id: String,
93    pub op: String,
94    pub inputs: Vec<String>,
95    #[serde(default)]
96    pub options: serde_json::Map<String, serde_json::Value>,
97    #[serde(default)]
98    pub outputs: Option<Vec<String>>,
99}
100
101pub fn new_graph_json() -> GraphJson {
102    GraphJson {
103        format: "webnn-graph-json".to_string(),
104        version: 1,
105        name: None,
106        inputs: BTreeMap::new(),
107        consts: BTreeMap::new(),
108        nodes: Vec::new(),
109        outputs: BTreeMap::new(),
110    }
111}
112
#[cfg(test)]
mod tests {
    use super::*;

    /// Every data type with its text-format spelling and its JSON spelling.
    fn spellings() -> [(DataType, &'static str, &'static str); 8] {
        [
            (DataType::Float32, "f32", "float32"),
            (DataType::Float16, "f16", "float16"),
            (DataType::Int32, "i32", "int32"),
            (DataType::Uint32, "u32", "uint32"),
            (DataType::Int64, "i64", "int64"),
            (DataType::Uint64, "u64", "uint64"),
            (DataType::Int8, "i8", "int8"),
            (DataType::Uint8, "u8", "uint8"),
        ]
    }

    #[test]
    fn test_datatype_from_wg() {
        for (dt, wg, _) in spellings().iter() {
            assert_eq!(DataType::from_wg(wg).as_ref(), Some(dt));
        }
        assert_eq!(DataType::from_wg("invalid"), None);
        // The JSON spelling is not a valid text-format spelling.
        assert_eq!(DataType::from_wg("float32"), None);
        assert_eq!(DataType::from_wg(""), None);
    }

    #[test]
    fn test_datatype_to_wg_text_round_trips() {
        for (dt, wg, _) in spellings().iter() {
            assert_eq!(dt.to_wg_text(), *wg);
            assert_eq!(DataType::from_wg(dt.to_wg_text()).as_ref(), Some(dt));
        }
    }

    #[test]
    fn test_datatype_json_names() {
        for (dt, _, json) in spellings().iter() {
            let quoted = format!("\"{}\"", json);
            assert_eq!(serde_json::to_string(dt).unwrap(), quoted);
            let parsed: DataType = serde_json::from_str(&quoted).unwrap();
            assert_eq!(&parsed, dt);
        }
    }

    #[test]
    fn test_new_graph_json() {
        let graph = new_graph_json();
        assert_eq!(graph.format, "webnn-graph-json");
        assert_eq!(graph.version, 1);
        assert!(graph.name.is_none());
        assert!(graph.inputs.is_empty());
        assert!(graph.consts.is_empty());
        assert!(graph.nodes.is_empty());
        assert!(graph.outputs.is_empty());
    }

    #[test]
    fn test_graph_json_optional_fields() {
        // An unset `name` is skipped on output.
        let encoded = serde_json::to_value(new_graph_json()).unwrap();
        assert!(encoded.get("name").is_none());
        assert_eq!(encoded["format"], "webnn-graph-json");
        assert_eq!(encoded["version"], 1);

        // `name` and `consts` may both be omitted on input.
        let text =
            r#"{"format":"webnn-graph-json","version":1,"inputs":{},"nodes":[],"outputs":{}}"#;
        let graph: GraphJson = serde_json::from_str(text).unwrap();
        assert!(graph.name.is_none());
        assert!(graph.consts.is_empty());
    }

    #[test]
    fn test_operand_desc_equality() {
        let desc1 = OperandDesc {
            data_type: DataType::Float32,
            shape: vec![1, 2, 3],
        };
        let desc2 = OperandDesc {
            data_type: DataType::Float32,
            shape: vec![1, 2, 3],
        };
        let desc3 = OperandDesc {
            data_type: DataType::Float16,
            shape: vec![1, 2, 3],
        };
        assert_eq!(desc1, desc2);
        assert_ne!(desc1, desc3);
    }

    #[test]
    fn test_operand_desc_json_keys() {
        let desc = OperandDesc {
            data_type: DataType::Float16,
            shape: vec![2, 3],
        };
        let expected = serde_json::json!({"dataType": "float16", "shape": [2, 3]});
        assert_eq!(serde_json::to_value(&desc).unwrap(), expected);
        assert_eq!(serde_json::from_value::<OperandDesc>(expected).unwrap(), desc);
    }

    #[test]
    fn test_const_init_wire_format() {
        // Each variant is internally tagged by `kind`, with camelCase variant names,
        // and the raw identifier `r#ref` appears on the wire as plain `ref`.
        let cases = vec![
            (
                ConstInit::Weights {
                    r#ref: "W".to_string(),
                },
                serde_json::json!({"kind": "weights", "ref": "W"}),
            ),
            (
                ConstInit::Scalar {
                    value: serde_json::json!(1.0),
                },
                serde_json::json!({"kind": "scalar", "value": 1.0}),
            ),
            (
                ConstInit::InlineBytes {
                    bytes: vec![1, 2, 3, 4],
                },
                serde_json::json!({"kind": "inlineBytes", "bytes": [1, 2, 3, 4]}),
            ),
        ];
        for (init, expected) in cases {
            assert_eq!(serde_json::to_value(&init).unwrap(), expected);
            let decoded: ConstInit = serde_json::from_value(expected).unwrap();
            assert_eq!(decoded, init);
        }
    }

    #[test]
    fn test_const_decl_json_keys() {
        let decl = ConstDecl {
            data_type: DataType::Int32,
            shape: vec![4],
            init: ConstInit::Weights {
                r#ref: "W".to_string(),
            },
        };
        let expected = serde_json::json!({
            "dataType": "int32",
            "shape": [4],
            "init": {"kind": "weights", "ref": "W"}
        });
        assert_eq!(serde_json::to_value(&decl).unwrap(), expected);
        assert_eq!(serde_json::from_value::<ConstDecl>(expected).unwrap(), decl);
    }

    #[test]
    fn test_node_optional_fields_default() {
        let node: Node =
            serde_json::from_str(r#"{"id":"n0","op":"relu","inputs":["x"]}"#).unwrap();
        assert_eq!(node.id, "n0");
        assert_eq!(node.op, "relu");
        assert_eq!(node.inputs, vec!["x".to_string()]);
        assert!(node.options.is_empty());
        assert!(node.outputs.is_none());
    }
}