1use std::collections::BTreeSet;
2
3use thiserror::Error;
4
5use crate::ast::{ConstInit, GraphJson};
6use crate::weights::{dtype_size, numel, WeightsManifest};
7
8#[derive(Debug, Error)]
9pub enum ValidateError {
10 #[error("unsupported format/version: {0} v{1}")]
11 BadFormat(String, u32),
12
13 #[error("outputs must be non-empty")]
14 EmptyOutputs,
15
16 #[error("duplicate node id: {0}")]
17 DuplicateNodeId(String),
18
19 #[error("unknown reference '{ref_}' used in node '{node}'")]
20 UnknownRef { node: String, ref_: String },
21
22 #[error("output '{out}' references unknown value '{ref_}'")]
23 BadOutputRef { out: String, ref_: String },
24
25 #[error("missing weights manifest entry for const ref '{0}'")]
26 MissingWeight(String),
27
28 #[error("weights type/shape mismatch for '{ref_}'")]
29 WeightMismatch { ref_: String },
30
31 #[error("weights byteLength mismatch for '{ref_}': expected {expected} got {got}")]
32 WeightByteLengthMismatch {
33 ref_: String,
34 expected: u64,
35 got: u64,
36 },
37}
38
39pub fn validate_graph(g: &GraphJson) -> Result<(), ValidateError> {
40 if g.format != "webnn-graph-json" || g.version != 1 {
41 return Err(ValidateError::BadFormat(g.format.clone(), g.version));
42 }
43 if g.outputs.is_empty() {
44 return Err(ValidateError::EmptyOutputs);
45 }
46
47 let mut known: BTreeSet<String> = g.inputs.keys().cloned().collect();
48 known.extend(g.consts.keys().cloned());
49
50 let mut ids = BTreeSet::new();
51
52 for n in &g.nodes {
53 if !ids.insert(n.id.clone()) {
54 return Err(ValidateError::DuplicateNodeId(n.id.clone()));
55 }
56 for r in &n.inputs {
57 if !known.contains(r) {
58 return Err(ValidateError::UnknownRef {
59 node: n.id.clone(),
60 ref_: r.clone(),
61 });
62 }
63 }
64 known.insert(n.id.clone());
65 if let Some(outs) = &n.outputs {
66 for o in outs {
67 known.insert(o.clone());
68 }
69 }
70 }
71
72 for (out, r) in &g.outputs {
73 if !known.contains(r) {
74 return Err(ValidateError::BadOutputRef {
75 out: out.clone(),
76 ref_: r.clone(),
77 });
78 }
79 }
80 Ok(())
81}
82
83pub fn validate_weights(g: &GraphJson, m: &WeightsManifest) -> Result<(), ValidateError> {
84 for c in g.consts.values() {
85 if let ConstInit::Weights { r#ref } = &c.init {
86 let entry = m
87 .tensors
88 .get(r#ref)
89 .ok_or_else(|| ValidateError::MissingWeight(r#ref.clone()))?;
90
91 if entry.data_type != c.data_type || entry.shape != c.shape {
92 return Err(ValidateError::WeightMismatch {
93 ref_: r#ref.clone(),
94 });
95 }
96
97 let expected = dtype_size(&c.data_type) * numel(&c.shape);
98 if entry.byte_length != expected {
99 return Err(ValidateError::WeightByteLengthMismatch {
100 ref_: r#ref.clone(),
101 expected,
102 got: entry.byte_length,
103 });
104 }
105 }
106 }
107 Ok(())
108}
109
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::{new_graph_json, ConstDecl, DataType, Node, OperandDesc};
    use crate::weights::TensorEntry;
    use std::collections::BTreeMap;

    /// Registers a float32 graph input called `name`. The shape does not matter to these checks.
    fn add_input(g: &mut GraphJson, name: &str) {
        g.inputs.insert(
            name.to_string(),
            OperandDesc {
                data_type: DataType::Float32,
                shape: vec![1, 10],
            },
        );
    }

    /// Builds a node with no options and no explicit outputs.
    fn node(id: &str, op: &str, inputs: &[&str]) -> Node {
        Node {
            id: id.to_string(),
            op: op.to_string(),
            inputs: inputs.iter().map(|s| s.to_string()).collect(),
            options: serde_json::Map::new(),
            outputs: None,
        }
    }

    /// Declares a float32 `[10, 5]` constant `name` backed by manifest key `weight_ref`.
    fn add_weights_const(g: &mut GraphJson, name: &str, weight_ref: &str) {
        g.consts.insert(
            name.to_string(),
            ConstDecl {
                data_type: DataType::Float32,
                shape: vec![10, 5],
                init: ConstInit::Weights {
                    r#ref: weight_ref.to_string(),
                },
            },
        );
    }

    /// A manifest with the standard header and no tensors.
    fn empty_manifest() -> WeightsManifest {
        WeightsManifest {
            format: "wg-weights-manifest".to_string(),
            version: 1,
            endianness: "little".to_string(),
            tensors: BTreeMap::new(),
        }
    }

    /// A `[10, 5]` manifest entry at offset 0 with the given type and byte length.
    fn tensor(data_type: DataType, byte_length: u64) -> TensorEntry {
        TensorEntry {
            data_type,
            shape: vec![10, 5],
            byte_offset: 0,
            byte_length,
            layout: None,
        }
    }

    #[test]
    fn test_validate_graph_success() {
        let mut g = new_graph_json();
        add_input(&mut g, "x");
        g.nodes.push(node("result", "relu", &["x"]));
        g.outputs.insert("result".to_string(), "result".to_string());

        assert!(validate_graph(&g).is_ok());
    }

    #[test]
    fn test_validate_graph_bad_format() {
        let mut g = new_graph_json();
        g.format = "invalid".to_string();
        g.outputs.insert("x".to_string(), "x".to_string());

        let result = validate_graph(&g);
        assert!(matches!(result, Err(ValidateError::BadFormat(_, _))));
    }

    #[test]
    fn test_validate_graph_bad_version() {
        let mut g = new_graph_json();
        g.version = 2;
        g.outputs.insert("x".to_string(), "x".to_string());

        let result = validate_graph(&g);
        assert!(matches!(result, Err(ValidateError::BadFormat(_, 2))));
    }

    #[test]
    fn test_validate_graph_empty_outputs() {
        let g = new_graph_json();
        let result = validate_graph(&g);
        assert!(matches!(result, Err(ValidateError::EmptyOutputs)));
    }

    #[test]
    fn test_validate_graph_duplicate_node_id() {
        let mut g = new_graph_json();
        add_input(&mut g, "x");
        g.nodes.push(node("result", "relu", &["x"]));
        g.nodes.push(node("result", "sigmoid", &["x"]));
        g.outputs.insert("result".to_string(), "result".to_string());

        let result = validate_graph(&g);
        assert!(matches!(
            &result,
            Err(ValidateError::DuplicateNodeId(id)) if id == "result"
        ));
    }

    #[test]
    fn test_validate_graph_unknown_ref() {
        let mut g = new_graph_json();
        add_input(&mut g, "x");
        g.nodes.push(node("result", "add", &["x", "unknown"]));
        g.outputs.insert("result".to_string(), "result".to_string());

        let result = validate_graph(&g);
        assert!(matches!(
            &result,
            Err(ValidateError::UnknownRef { node: id, ref_ }) if id == "result" && ref_ == "unknown"
        ));
    }

    #[test]
    fn test_validate_graph_forward_ref_rejected() {
        let mut g = new_graph_json();
        add_input(&mut g, "x");
        // "a" consumes "b" before "b" is declared; nodes must appear in dependency order.
        g.nodes.push(node("a", "relu", &["b"]));
        g.nodes.push(node("b", "relu", &["x"]));
        g.outputs.insert("out".to_string(), "a".to_string());

        let result = validate_graph(&g);
        assert!(matches!(
            &result,
            Err(ValidateError::UnknownRef { node: id, ref_ }) if id == "a" && ref_ == "b"
        ));
    }

    #[test]
    fn test_validate_graph_self_ref_rejected() {
        let mut g = new_graph_json();
        add_input(&mut g, "x");
        g.nodes.push(node("a", "relu", &["a"]));
        g.outputs.insert("out".to_string(), "a".to_string());

        let result = validate_graph(&g);
        assert!(matches!(result, Err(ValidateError::UnknownRef { .. })));
    }

    #[test]
    fn test_validate_graph_const_ref() {
        let mut g = new_graph_json();
        add_input(&mut g, "x");
        g.consts.insert(
            "bias".to_string(),
            ConstDecl {
                data_type: DataType::Float32,
                shape: vec![1],
                init: ConstInit::Scalar {
                    value: serde_json::json!(0.5),
                },
            },
        );
        g.nodes.push(node("sum", "add", &["x", "bias"]));
        g.outputs.insert("sum".to_string(), "sum".to_string());

        assert!(validate_graph(&g).is_ok());
    }

    #[test]
    fn test_validate_graph_node_outputs_referenceable() {
        let mut g = new_graph_json();
        add_input(&mut g, "x");
        let mut split = node("split", "split", &["x"]);
        split.outputs = Some(vec!["s0".to_string(), "s1".to_string()]);
        g.nodes.push(split);
        g.nodes.push(node("sum", "add", &["s0", "s1"]));
        g.outputs.insert("sum".to_string(), "sum".to_string());

        assert!(validate_graph(&g).is_ok());
    }

    #[test]
    fn test_validate_graph_bad_output_ref() {
        let mut g = new_graph_json();
        add_input(&mut g, "x");
        g.outputs
            .insert("out".to_string(), "nonexistent".to_string());

        let result = validate_graph(&g);
        assert!(matches!(
            &result,
            Err(ValidateError::BadOutputRef { out, ref_ }) if out == "out" && ref_ == "nonexistent"
        ));
    }

    #[test]
    fn test_validate_weights_success() {
        let mut g = new_graph_json();
        add_weights_const(&mut g, "W", "W");

        let mut manifest = empty_manifest();
        manifest
            .tensors
            .insert("W".to_string(), tensor(DataType::Float32, 200));

        assert!(validate_weights(&g, &manifest).is_ok());
    }

    #[test]
    fn test_validate_weights_looks_up_by_ref_not_const_name() {
        let mut g = new_graph_json();
        add_weights_const(&mut g, "W", "layer0/weight");

        let mut manifest = empty_manifest();
        manifest
            .tensors
            .insert("layer0/weight".to_string(), tensor(DataType::Float32, 200));

        assert!(validate_weights(&g, &manifest).is_ok());
    }

    #[test]
    fn test_validate_weights_missing_weight() {
        let mut g = new_graph_json();
        add_weights_const(&mut g, "W", "W");

        let result = validate_weights(&g, &empty_manifest());
        assert!(matches!(
            &result,
            Err(ValidateError::MissingWeight(r)) if r == "W"
        ));
    }

    #[test]
    fn test_validate_weights_type_mismatch() {
        let mut g = new_graph_json();
        add_weights_const(&mut g, "W", "W");

        let mut manifest = empty_manifest();
        manifest
            .tensors
            .insert("W".to_string(), tensor(DataType::Float16, 100));

        let result = validate_weights(&g, &manifest);
        assert!(matches!(result, Err(ValidateError::WeightMismatch { .. })));
    }

    #[test]
    fn test_validate_weights_shape_mismatch() {
        let mut g = new_graph_json();
        add_weights_const(&mut g, "W", "W");

        let mut entry = tensor(DataType::Float32, 200);
        entry.shape = vec![5, 10];
        let mut manifest = empty_manifest();
        manifest.tensors.insert("W".to_string(), entry);

        let result = validate_weights(&g, &manifest);
        assert!(matches!(result, Err(ValidateError::WeightMismatch { .. })));
    }

    #[test]
    fn test_validate_weights_byte_length_mismatch() {
        let mut g = new_graph_json();
        add_weights_const(&mut g, "W", "W");

        let mut manifest = empty_manifest();
        manifest
            .tensors
            .insert("W".to_string(), tensor(DataType::Float32, 100));

        let result = validate_weights(&g, &manifest);
        assert!(matches!(
            result,
            Err(ValidateError::WeightByteLengthMismatch {
                expected: 200,
                got: 100,
                ..
            })
        ));
    }

    #[test]
    fn test_validate_weights_scalar_init_skipped() {
        let mut g = new_graph_json();
        g.consts.insert(
            "scale".to_string(),
            ConstDecl {
                data_type: DataType::Float32,
                shape: vec![1],
                init: ConstInit::Scalar {
                    value: serde_json::json!(1.0),
                },
            },
        );

        assert!(validate_weights(&g, &empty_manifest()).is_ok());
    }
}