webnn_graph/
validate.rs

1use std::collections::BTreeSet;
2
3use thiserror::Error;
4
5use crate::ast::{ConstInit, GraphJson};
6use crate::weights::{dtype_size, numel, WeightsManifest};
7
/// Errors reported by [`validate_graph`] and [`validate_weights`].
///
/// The `Display` text of each variant is the message given in its
/// `#[error(...)]` attribute (generated by `thiserror`).
#[derive(Debug, Error)]
pub enum ValidateError {
    /// The graph's `format`/`version` pair is not accepted; carries the
    /// offending format string and version number.
    #[error("unsupported format/version: {0} v{1}")]
    BadFormat(String, u32),

    /// The graph declares no outputs at all.
    #[error("outputs must be non-empty")]
    EmptyOutputs,

    /// A node's id is already taken; carries the offending id.
    #[error("duplicate node id: {0}")]
    DuplicateNodeId(String),

    /// Node `node` lists an input `ref_` that is not defined before that node.
    #[error("unknown reference '{ref_}' used in node '{node}'")]
    UnknownRef { node: String, ref_: String },

    /// Graph output `out` maps to a value name `ref_` that nothing defines.
    #[error("output '{out}' references unknown value '{ref_}'")]
    BadOutputRef { out: String, ref_: String },

    /// A weights-backed const references a tensor absent from the manifest.
    #[error("missing weights manifest entry for const ref '{0}'")]
    MissingWeight(String),

    /// The manifest entry for `ref_` does not agree with the const's declared
    /// data type or shape.
    #[error("weights type/shape mismatch for '{ref_}'")]
    WeightMismatch { ref_: String },

    /// The manifest's recorded byte length (`got`) differs from the size
    /// implied by the const's data type and shape (`expected`).
    #[error("weights byteLength mismatch for '{ref_}': expected {expected} got {got}")]
    WeightByteLengthMismatch {
        ref_: String,
        expected: u64,
        got: u64,
    },
}
38
39pub fn validate_graph(g: &GraphJson) -> Result<(), ValidateError> {
40    if g.format != "webnn-graph-json" || g.version != 1 {
41        return Err(ValidateError::BadFormat(g.format.clone(), g.version));
42    }
43    if g.outputs.is_empty() {
44        return Err(ValidateError::EmptyOutputs);
45    }
46
47    let mut known: BTreeSet<String> = g.inputs.keys().cloned().collect();
48    known.extend(g.consts.keys().cloned());
49
50    let mut ids = BTreeSet::new();
51
52    for n in &g.nodes {
53        if !ids.insert(n.id.clone()) {
54            return Err(ValidateError::DuplicateNodeId(n.id.clone()));
55        }
56        for r in &n.inputs {
57            if !known.contains(r) {
58                return Err(ValidateError::UnknownRef {
59                    node: n.id.clone(),
60                    ref_: r.clone(),
61                });
62            }
63        }
64        known.insert(n.id.clone());
65        if let Some(outs) = &n.outputs {
66            for o in outs {
67                known.insert(o.clone());
68            }
69        }
70    }
71
72    for (out, r) in &g.outputs {
73        if !known.contains(r) {
74            return Err(ValidateError::BadOutputRef {
75                out: out.clone(),
76                ref_: r.clone(),
77            });
78        }
79    }
80    Ok(())
81}
82
83pub fn validate_weights(g: &GraphJson, m: &WeightsManifest) -> Result<(), ValidateError> {
84    for c in g.consts.values() {
85        if let ConstInit::Weights { r#ref } = &c.init {
86            let entry = m
87                .tensors
88                .get(r#ref)
89                .ok_or_else(|| ValidateError::MissingWeight(r#ref.clone()))?;
90
91            if entry.data_type != c.data_type || entry.shape != c.shape {
92                return Err(ValidateError::WeightMismatch {
93                    ref_: r#ref.clone(),
94                });
95            }
96
97            let expected = dtype_size(&c.data_type) * numel(&c.shape);
98            if entry.byte_length != expected {
99                return Err(ValidateError::WeightByteLengthMismatch {
100                    ref_: r#ref.clone(),
101                    expected,
102                    got: entry.byte_length,
103                });
104            }
105        }
106    }
107    Ok(())
108}
109
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::{new_graph_json, ConstDecl, DataType, Node, OperandDesc};
    use crate::weights::TensorEntry;
    use std::collections::BTreeMap;

    /// Builds a node with no options and no explicit output list.
    fn node(id: &str, op: &str, inputs: &[&str]) -> Node {
        Node {
            id: id.to_string(),
            op: op.to_string(),
            inputs: inputs.iter().map(|s| s.to_string()).collect(),
            options: serde_json::Map::new(),
            outputs: None,
        }
    }

    /// A fresh graph with one float32 input `x` of shape `[1]`.
    fn graph_with_input_x() -> GraphJson {
        let mut g = new_graph_json();
        g.inputs.insert(
            "x".to_string(),
            OperandDesc {
                data_type: DataType::Float32,
                shape: vec![1],
            },
        );
        g
    }

    /// A fresh graph declaring a float32 `[10, 5]` const `W` backed by the
    /// manifest tensor `"W"`, plus a placeholder output `x -> x`.
    fn graph_with_weight_w() -> GraphJson {
        let mut g = new_graph_json();
        g.consts.insert(
            "W".to_string(),
            ConstDecl {
                data_type: DataType::Float32,
                shape: vec![10, 5],
                init: ConstInit::Weights {
                    r#ref: "W".to_string(),
                },
            },
        );
        g.outputs.insert("x".to_string(), "x".to_string());
        g
    }

    /// A little-endian v1 manifest, optionally holding `entry` under `"W"`.
    fn manifest_with(entry: Option<TensorEntry>) -> WeightsManifest {
        let mut m = WeightsManifest {
            format: "wg-weights-manifest".to_string(),
            version: 1,
            endianness: "little".to_string(),
            tensors: BTreeMap::new(),
        };
        if let Some(e) = entry {
            m.tensors.insert("W".to_string(), e);
        }
        m
    }

    #[test]
    fn test_validate_graph_success() {
        let mut g = new_graph_json();
        g.inputs.insert(
            "x".to_string(),
            OperandDesc {
                data_type: DataType::Float32,
                shape: vec![1, 10],
            },
        );
        g.nodes.push(node("result", "relu", &["x"]));
        g.outputs.insert("result".to_string(), "result".to_string());

        assert!(validate_graph(&g).is_ok());
    }

    #[test]
    fn test_validate_graph_bad_format() {
        let mut g = new_graph_json();
        g.format = "invalid".to_string();
        g.outputs.insert("x".to_string(), "x".to_string());

        assert!(matches!(
            validate_graph(&g),
            Err(ValidateError::BadFormat(..))
        ));
    }

    #[test]
    fn test_validate_graph_empty_outputs() {
        assert!(matches!(
            validate_graph(&new_graph_json()),
            Err(ValidateError::EmptyOutputs)
        ));
    }

    #[test]
    fn test_validate_graph_duplicate_node_id() {
        let mut g = graph_with_input_x();
        g.nodes.push(node("result", "relu", &["x"]));
        g.nodes.push(node("result", "sigmoid", &["x"]));
        g.outputs.insert("result".to_string(), "result".to_string());

        assert!(matches!(
            validate_graph(&g),
            Err(ValidateError::DuplicateNodeId(_))
        ));
    }

    #[test]
    fn test_validate_graph_unknown_ref() {
        let mut g = graph_with_input_x();
        g.nodes.push(node("result", "add", &["x", "unknown"]));
        g.outputs.insert("result".to_string(), "result".to_string());

        assert!(matches!(
            validate_graph(&g),
            Err(ValidateError::UnknownRef { .. })
        ));
    }

    #[test]
    fn test_validate_graph_bad_output_ref() {
        let mut g = graph_with_input_x();
        g.outputs
            .insert("out".to_string(), "nonexistent".to_string());

        assert!(matches!(
            validate_graph(&g),
            Err(ValidateError::BadOutputRef { .. })
        ));
    }

    #[test]
    fn test_validate_weights_success() {
        let g = graph_with_weight_w();
        let m = manifest_with(Some(TensorEntry {
            data_type: DataType::Float32,
            shape: vec![10, 5],
            byte_offset: 0,
            byte_length: 200, // 10 * 5 elements * 4 bytes
            layout: None,
        }));

        assert!(validate_weights(&g, &m).is_ok());
    }

    #[test]
    fn test_validate_weights_missing_weight() {
        let g = graph_with_weight_w();
        let m = manifest_with(None);

        assert!(matches!(
            validate_weights(&g, &m),
            Err(ValidateError::MissingWeight(_))
        ));
    }

    #[test]
    fn test_validate_weights_type_mismatch() {
        let g = graph_with_weight_w();
        let m = manifest_with(Some(TensorEntry {
            data_type: DataType::Float16, // disagrees with the Float32 const
            shape: vec![10, 5],
            byte_offset: 0,
            byte_length: 100,
            layout: None,
        }));

        assert!(matches!(
            validate_weights(&g, &m),
            Err(ValidateError::WeightMismatch { .. })
        ));
    }

    #[test]
    fn test_validate_weights_byte_length_mismatch() {
        let g = graph_with_weight_w();
        let m = manifest_with(Some(TensorEntry {
            data_type: DataType::Float32,
            shape: vec![10, 5],
            byte_offset: 0,
            byte_length: 100, // should be 200
            layout: None,
        }));

        assert!(matches!(
            validate_weights(&g, &m),
            Err(ValidateError::WeightByteLengthMismatch { .. })
        ));
    }

    #[test]
    fn test_validate_weights_scalar_init_skipped() {
        let mut g = new_graph_json();
        g.consts.insert(
            "scale".to_string(),
            ConstDecl {
                data_type: DataType::Float32,
                shape: vec![1],
                init: ConstInit::Scalar {
                    value: serde_json::json!(1.0),
                },
            },
        );
        g.outputs.insert("x".to_string(), "x".to_string());

        // Scalar-initialised consts need no manifest entry, so an empty
        // manifest must still validate.
        assert!(validate_weights(&g, &manifest_with(None)).is_ok());
    }
}