Skip to main content

webnn_graph/
ast.rs

1use serde::{Deserialize, Serialize};
2use std::collections::BTreeMap;
3
4#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord)]
5#[serde(rename_all = "camelCase")]
6pub struct DynamicDimension {
7    pub name: String,
8    pub max_size: u32,
9}
10
11#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord)]
12#[serde(untagged)]
13pub enum Dimension {
14    Static(u32),
15    Dynamic(DynamicDimension),
16}
17
18pub fn to_dimension_vector(shape: &[u32]) -> Vec<Dimension> {
19    shape.iter().copied().map(Dimension::Static).collect()
20}
21
22pub fn get_static_or_max_size(dim: &Dimension) -> u32 {
23    match dim {
24        Dimension::Static(v) => *v,
25        Dimension::Dynamic(d) => d.max_size,
26    }
27}
28
29#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct GraphJson {
31    pub format: String, // "webnn-graph-json"
32    pub version: u32,   // 1 or 2
33    #[serde(skip_serializing_if = "Option::is_none")]
34    pub name: Option<String>,
35    #[serde(default)]
36    pub quantized: bool,
37    pub inputs: BTreeMap<String, OperandDesc>,
38    #[serde(default)]
39    pub consts: BTreeMap<String, ConstDecl>,
40    pub nodes: Vec<Node>,
41    // output_name -> value reference name
42    pub outputs: BTreeMap<String, String>,
43}
44
45#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
46pub struct OperandDesc {
47    #[serde(rename = "dataType")]
48    pub data_type: DataType,
49    pub shape: Vec<Dimension>,
50}
51
52#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
53pub enum DataType {
54    #[serde(rename = "float32")]
55    Float32,
56    #[serde(rename = "float16")]
57    Float16,
58    #[serde(rename = "int4")]
59    Int4,
60    #[serde(rename = "uint4")]
61    Uint4,
62    #[serde(rename = "int32")]
63    Int32,
64    #[serde(rename = "uint32")]
65    Uint32,
66    #[serde(rename = "int64")]
67    Int64,
68    #[serde(rename = "uint64")]
69    Uint64,
70    #[serde(rename = "int8")]
71    Int8,
72    #[serde(rename = "uint8")]
73    Uint8,
74}
75
76impl DataType {
77    pub fn from_wg(s: &str) -> Option<Self> {
78        match s {
79            "f32" => Some(Self::Float32),
80            "f16" => Some(Self::Float16),
81            "i4" => Some(Self::Int4),
82            "u4" => Some(Self::Uint4),
83            "i32" => Some(Self::Int32),
84            "u32" => Some(Self::Uint32),
85            "i64" => Some(Self::Int64),
86            "u64" => Some(Self::Uint64),
87            "i8" => Some(Self::Int8),
88            "u8" => Some(Self::Uint8),
89            _ => None,
90        }
91    }
92
93    pub fn to_wg_text(&self) -> &'static str {
94        match self {
95            Self::Float32 => "f32",
96            Self::Float16 => "f16",
97            Self::Int4 => "i4",
98            Self::Uint4 => "u4",
99            Self::Int32 => "i32",
100            Self::Uint32 => "u32",
101            Self::Int64 => "i64",
102            Self::Uint64 => "u64",
103            Self::Int8 => "i8",
104            Self::Uint8 => "u8",
105        }
106    }
107}
108
109#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
110pub struct ConstDecl {
111    #[serde(rename = "dataType")]
112    pub data_type: DataType,
113    pub shape: Vec<u32>,
114    pub init: ConstInit,
115}
116
117#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
118#[serde(tag = "kind", rename_all = "camelCase")]
119pub enum ConstInit {
120    Weights { r#ref: String },
121    Scalar { value: serde_json::Value },
122    InlineBytes { bytes: Vec<u8> },
123}
124
125#[derive(Debug, Clone, Serialize, Deserialize)]
126pub struct Node {
127    pub id: String,
128    pub op: String,
129    pub inputs: Vec<String>,
130    #[serde(default)]
131    pub options: serde_json::Map<String, serde_json::Value>,
132    #[serde(default)]
133    pub outputs: Option<Vec<String>>,
134}
135
136pub fn new_graph_json() -> GraphJson {
137    GraphJson {
138        format: "webnn-graph-json".to_string(),
139        version: 2,
140        name: None,
141        quantized: false,
142        inputs: BTreeMap::new(),
143        consts: BTreeMap::new(),
144        nodes: Vec::new(),
145        outputs: BTreeMap::new(),
146    }
147}
148
149impl OperandDesc {
150    pub fn static_shape(&self) -> Option<Vec<u32>> {
151        let mut shape = Vec::with_capacity(self.shape.len());
152        for dim in &self.shape {
153            match dim {
154                Dimension::Static(v) => shape.push(*v),
155                Dimension::Dynamic(_) => return None,
156            }
157        }
158        Some(shape)
159    }
160}
161
162#[cfg(test)]
163mod tests {
164    use super::*;
165
166    #[test]
167    fn test_datatype_from_wg() {
168        assert_eq!(DataType::from_wg("f32"), Some(DataType::Float32));
169        assert_eq!(DataType::from_wg("f16"), Some(DataType::Float16));
170        assert_eq!(DataType::from_wg("i32"), Some(DataType::Int32));
171        assert_eq!(DataType::from_wg("u32"), Some(DataType::Uint32));
172        assert_eq!(DataType::from_wg("i64"), Some(DataType::Int64));
173        assert_eq!(DataType::from_wg("u64"), Some(DataType::Uint64));
174        assert_eq!(DataType::from_wg("i8"), Some(DataType::Int8));
175        assert_eq!(DataType::from_wg("u8"), Some(DataType::Uint8));
176        assert_eq!(DataType::from_wg("invalid"), None);
177        assert_eq!(DataType::from_wg("float32"), None);
178    }
179
180    #[test]
181    fn test_new_graph_json() {
182        let graph = new_graph_json();
183        assert_eq!(graph.format, "webnn-graph-json");
184        assert_eq!(graph.version, 2);
185        assert!(graph.inputs.is_empty());
186        assert!(graph.consts.is_empty());
187        assert!(graph.nodes.is_empty());
188        assert!(graph.outputs.is_empty());
189    }
190
191    #[test]
192    fn test_operand_desc_equality() {
193        let desc1 = OperandDesc {
194            data_type: DataType::Float32,
195            shape: to_dimension_vector(&[1, 2, 3]),
196        };
197        let desc2 = OperandDesc {
198            data_type: DataType::Float32,
199            shape: to_dimension_vector(&[1, 2, 3]),
200        };
201        let desc3 = OperandDesc {
202            data_type: DataType::Float16,
203            shape: to_dimension_vector(&[1, 2, 3]),
204        };
205        assert_eq!(desc1, desc2);
206        assert_ne!(desc1, desc3);
207    }
208
209    #[test]
210    fn test_const_init_variants() {
211        let weights_init = ConstInit::Weights {
212            r#ref: "W".to_string(),
213        };
214        let scalar_init = ConstInit::Scalar {
215            value: serde_json::json!(1.0),
216        };
217        let bytes_init = ConstInit::InlineBytes {
218            bytes: vec![1, 2, 3, 4],
219        };
220
221        // Test that they're different variants
222        assert!(matches!(weights_init, ConstInit::Weights { .. }));
223        assert!(matches!(scalar_init, ConstInit::Scalar { .. }));
224        assert!(matches!(bytes_init, ConstInit::InlineBytes { .. }));
225    }
226}