1use std::collections::HashMap;
5use std::fs;
6use std::path::{Path, PathBuf};
7
8use half::bf16;
9use safetensors::tensor::Dtype as StDtype;
10use safetensors::SafeTensors;
11use serde::Deserialize;
12use thiserror::Error;
13
14use crate::ast::{ConstInit, DataType as AstDataType, GraphJson};
15
/// Default file name for a graph JSON document.
pub const DEFAULT_PATH_JSON: &str = "model.json";
/// Fallback raw weights blob probed next to the graph (last discovery candidate) when no
/// `weights_path` is given.
pub const DEFAULT_PATH_WEIGHTS: &str = "model.weights";
/// Fallback safetensors archive probed next to the graph when no `weights_path` is given
/// (tried before [`DEFAULT_PATH_WEIGHTS`]).
pub const DEFAULT_PATH_SAFETENSORS: &str = "model.safetensors";
/// Fallback weights manifest probed next to the graph (after `<stem>.manifest.json`) when a raw
/// weights blob is used and no `manifest_path` is given.
pub const DEFAULT_PATH_MANIFEST: &str = "manifest.json";
24
/// Errors produced while replacing `ConstInit::Weights` references with inline constant bytes.
#[derive(Debug, Error)]
pub enum WeightResolveError {
    /// Reading a weights, safetensors, or manifest file from disk failed.
    #[error("failed to read `{path}`: {source}")]
    ReadFile {
        /// File that could not be read.
        path: PathBuf,
        #[source]
        source: std::io::Error,
    },
    /// The manifest file was read but could not be deserialized as a weights manifest.
    #[error("failed to parse manifest JSON at `{path}`: {source}")]
    ManifestJson {
        /// Manifest file that failed to parse.
        path: PathBuf,
        #[source]
        source: serde_json::Error,
    },
    /// Safetensors-path failure: the archive could not be parsed, tensor names were ambiguous
    /// after sanitizing, a referenced tensor was not found, or its shape / dtype / data length
    /// disagrees with the graph constant.
    #[error("[safetensors] {0}")]
    Safetensors(String),
    /// Manifest-path failure: no manifest entry for a weight ref, or an entry whose byte range
    /// does not fit in `usize` or in the weights file.
    #[error("[manifest-weights] {0}")]
    Manifest(String),
    /// An explicitly given weights/manifest path does not exist, or none could be discovered
    /// next to the graph.
    #[error("[weights] {0}")]
    Missing(String),
}
52
53fn graph_has_external_weight_refs(graph_json: &GraphJson) -> bool {
54 graph_json
55 .consts
56 .values()
57 .any(|c| matches!(c.init, ConstInit::Weights { .. }))
58}
59
/// Maps a tensor name to the identifier-safe spelling used by graph weight refs.
///
/// Every `::` becomes `__` first (non-overlapping, left to right), then every remaining `.`
/// becomes `_`.
#[inline]
fn sanitize_weight_key(name: &str) -> String {
    let scoped = name.replace("::", "__");
    scoped
        .chars()
        .map(|ch| if ch == '.' { '_' } else { ch })
        .collect()
}
65
66fn safetensors_st_dtype_matches_ast(st: StDtype, ast: &AstDataType) -> bool {
67 matches!(
68 (ast, st),
69 (AstDataType::Float32, StDtype::F32)
70 | (AstDataType::Float16, StDtype::F16)
71 | (AstDataType::Int32, StDtype::I32)
72 | (AstDataType::Uint32, StDtype::U32)
73 | (AstDataType::Int64, StDtype::I64)
74 | (AstDataType::Uint64, StDtype::U64)
75 | (AstDataType::Int8, StDtype::I8)
76 | (AstDataType::Uint8, StDtype::U8)
77 )
78}
79
/// Returns `true` when a safetensors tensor shape equals a graph constant's declared shape.
///
/// The ranks must match and every dimension must be equal. Dimensions are compared without
/// truncation: a safetensors dimension that does not fit in `u32` never matches.
fn st_shape_matches_const(st_shape: &[usize], const_shape: &[u32]) -> bool {
    st_shape.len() == const_shape.len()
        && st_shape
            .iter()
            .zip(const_shape)
            // `s as u32` would silently wrap large dimensions (e.g. 2^32 + 2 -> 2) and accept a
            // mismatched tensor; the checked conversion rejects them instead.
            .all(|(&s, &c)| u32::try_from(s) == Ok(c))
}
89
90fn bf16_bytes_to_f32_le_bytes(data: &[u8]) -> Result<Vec<u8>, WeightResolveError> {
92 if !data.len().is_multiple_of(2) {
93 return Err(WeightResolveError::Safetensors(format!(
94 "BF16 data length {} is not a multiple of 2",
95 data.len()
96 )));
97 }
98 let mut out = Vec::with_capacity(data.len() * 2);
99 for chunk in data.chunks_exact(2) {
100 let bits = u16::from_le_bytes([chunk[0], chunk[1]]);
101 let v = bf16::from_bits(bits).to_f32();
102 out.extend_from_slice(&v.to_le_bytes());
103 }
104 Ok(out)
105}
106
107fn safetensors_sanitized_name_map(
108 st: &SafeTensors<'_>,
109) -> Result<HashMap<String, String>, WeightResolveError> {
110 let mut out: HashMap<String, String> = HashMap::new();
111 for name in st.names() {
112 let sanitized = sanitize_weight_key(name);
113 if let Some(prev) = out.insert(sanitized.clone(), name.to_string()) {
114 if prev.as_str() != name {
115 return Err(WeightResolveError::Safetensors(format!(
116 "ambiguous sanitized tensor name `{sanitized}` (both `{prev}` and `{name}`)"
117 )));
118 }
119 }
120 }
121 Ok(out)
122}
123
124fn resolve_tensor_view<'a>(
125 st: &'a SafeTensors<'a>,
126 sanitized_map: &HashMap<String, String>,
127 r#ref: &str,
128) -> Result<safetensors::tensor::TensorView<'a>, WeightResolveError> {
129 if let Ok(v) = st.tensor(r#ref) {
130 return Ok(v);
131 }
132 let orig = sanitized_map.get(r#ref).ok_or_else(|| {
133 WeightResolveError::Safetensors(format!("tensor `{ref}` not found in safetensors archive"))
134 })?;
135 st.tensor(orig.as_str())
136 .map_err(|e| WeightResolveError::Safetensors(format!("tensor `{ref}` (via `{orig}`): {e}")))
137}
138
139fn inline_weights_from_safetensors(
140 graph_json: &mut GraphJson,
141 safetensors_path: &Path,
142) -> Result<(), WeightResolveError> {
143 let weight_ref_count = graph_json
144 .consts
145 .values()
146 .filter(|c| matches!(c.init, ConstInit::Weights { .. }))
147 .count();
148 eprintln!(
149 "[webnn-graph] resolve safetensors: path=`{}` weight_ref_count={}",
150 safetensors_path.display(),
151 weight_ref_count
152 );
153
154 let bytes = fs::read(safetensors_path).map_err(|source| WeightResolveError::ReadFile {
155 path: safetensors_path.to_path_buf(),
156 source,
157 })?;
158 let st = SafeTensors::deserialize(&bytes).map_err(|e| {
159 WeightResolveError::Safetensors(format!("`{}`: {e}", safetensors_path.display()))
160 })?;
161 let sanitized_map = safetensors_sanitized_name_map(&st)?;
162
163 for (const_name, const_decl) in graph_json.consts.iter_mut() {
164 let ConstInit::Weights { r#ref } = &const_decl.init else {
165 continue;
166 };
167 let view = match resolve_tensor_view(&st, &sanitized_map, r#ref) {
168 Ok(v) => v,
169 Err(e) => {
170 eprintln!(
171 "[webnn-graph] warning: safetensors could not resolve weight ref `{ref}` \
172 (constant `{const_name}`) from `{}`: {e}",
173 safetensors_path.display()
174 );
175 return Err(e);
176 }
177 };
178 if !st_shape_matches_const(view.shape(), &const_decl.shape) {
179 let msg = format!(
180 "shape mismatch for weight `{ref}` (constant `{const_name}`): graph {:?} vs safetensors {:?}",
181 const_decl.shape,
182 view.shape()
183 );
184 eprintln!(
185 "[webnn-graph] warning: safetensors could not resolve weight `{ref}` \
186 (constant `{const_name}`) from `{}`: {msg}",
187 safetensors_path.display()
188 );
189 return Err(WeightResolveError::Safetensors(msg));
190 }
191
192 let st_dtype = view.dtype();
193 let raw = view.data();
194 let bytes = if safetensors_st_dtype_matches_ast(st_dtype, &const_decl.data_type) {
195 raw.to_vec()
196 } else if matches!(
197 (&const_decl.data_type, st_dtype),
198 (AstDataType::Float32, StDtype::BF16)
199 ) {
200 let elem_count: usize = const_decl.shape.iter().map(|&x| x as usize).product();
201 let expected = elem_count.checked_mul(2).ok_or_else(|| {
202 WeightResolveError::Safetensors(format!(
203 "element count overflow for weight `{ref}` (constant `{const_name}`)"
204 ))
205 })?;
206 if raw.len() != expected {
207 return Err(WeightResolveError::Safetensors(format!(
208 "BF16 tensor `{ref}` (constant `{const_name}`): byte length {} != expected {} ({} BF16 elements)",
209 raw.len(),
210 expected,
211 elem_count
212 )));
213 }
214 eprintln!(
215 "[webnn-graph] safetensors: converting BF16 → float32 for weight `{ref}` (constant `{const_name}`)"
216 );
217 bf16_bytes_to_f32_le_bytes(raw)?
218 } else {
219 let msg = format!(
220 "dtype mismatch for weight `{ref}` (constant `{const_name}`): graph declares {:?} but safetensors has {:?}",
221 const_decl.data_type,
222 st_dtype
223 );
224 eprintln!(
225 "[webnn-graph] warning: safetensors could not resolve weight `{ref}` \
226 (constant `{const_name}`) from `{}`: {msg}",
227 safetensors_path.display()
228 );
229 return Err(WeightResolveError::Safetensors(msg));
230 };
231
232 const_decl.init = ConstInit::InlineBytes { bytes };
233 }
234
235 let still_count = graph_json
236 .consts
237 .values()
238 .filter(|c| matches!(c.init, ConstInit::Weights { .. }))
239 .count();
240 if still_count > 0 {
241 return Err(WeightResolveError::Safetensors(format!(
242 "safetensors `{}` did not provide all tensors referenced by the graph ({still_count} still missing)",
243 safetensors_path.display()
244 )));
245 }
246
247 Ok(())
248}
249
250#[derive(Debug, Deserialize)]
252struct FlexibleManifest {
253 #[serde(default)]
254 tensors: HashMap<String, FlexibleTensorEntry>,
255}
256
257#[derive(Debug, Deserialize, Clone)]
258struct FlexibleTensorEntry {
259 #[serde(rename = "byteOffset")]
260 byte_offset: u64,
261 #[serde(rename = "byteLength")]
262 byte_length: u64,
263}
264
265fn inline_weights_from_manifest(
266 graph_json: &mut GraphJson,
267 manifest_path: &Path,
268 weights_path: &Path,
269) -> Result<(), WeightResolveError> {
270 let manifest_text =
271 fs::read_to_string(manifest_path).map_err(|source| WeightResolveError::ReadFile {
272 path: manifest_path.to_path_buf(),
273 source,
274 })?;
275 let weights_bytes = fs::read(weights_path).map_err(|source| WeightResolveError::ReadFile {
276 path: weights_path.to_path_buf(),
277 source,
278 })?;
279
280 let manifest: FlexibleManifest = serde_json::from_str(&manifest_text).map_err(|source| {
281 WeightResolveError::ManifestJson {
282 path: manifest_path.to_path_buf(),
283 source,
284 }
285 })?;
286
287 let mut manifest_by_sanitized: HashMap<String, Vec<FlexibleTensorEntry>> = HashMap::new();
288 for (name, entry) in &manifest.tensors {
289 let sanitized = sanitize_weight_key(name);
290 manifest_by_sanitized
291 .entry(sanitized)
292 .or_default()
293 .push(entry.clone());
294 }
295
296 for (const_name, const_decl) in graph_json.consts.iter_mut() {
297 let ConstInit::Weights { r#ref } = &const_decl.init else {
298 continue;
299 };
300 let entry = manifest
301 .tensors
302 .get(r#ref)
303 .cloned()
304 .or_else(|| {
305 manifest_by_sanitized.get(r#ref).and_then(|entries| {
306 if entries.len() == 1 {
307 Some(entries[0].clone())
308 } else {
309 None
310 }
311 })
312 })
313 .ok_or_else(|| {
314 WeightResolveError::Manifest(format!(
315 "no manifest tensor entry for weight ref `{ref}` (constant `{const_name}`)"
316 ))
317 })?;
318
319 let start = usize::try_from(entry.byte_offset).map_err(|_| {
320 WeightResolveError::Manifest(format!(
321 "byteOffset {} for `{ref}` does not fit in usize",
322 entry.byte_offset
323 ))
324 })?;
325 let len = usize::try_from(entry.byte_length).map_err(|_| {
326 WeightResolveError::Manifest(format!(
327 "byteLength {} for `{ref}` does not fit in usize",
328 entry.byte_length
329 ))
330 })?;
331 let end = start.checked_add(len).ok_or_else(|| {
332 WeightResolveError::Manifest(format!("byte range overflow for `{ref}`"))
333 })?;
334 if end > weights_bytes.len() {
335 return Err(WeightResolveError::Manifest(format!(
336 "byte range [{start}, {end}) for `{ref}` exceeds weights file length {} (`{}`)",
337 weights_bytes.len(),
338 weights_path.display()
339 )));
340 }
341 const_decl.init = ConstInit::InlineBytes {
342 bytes: weights_bytes[start..end].to_vec(),
343 };
344 }
345 Ok(())
346}
347
/// Interprets `path_str` relative to the directory that contains `graph_path`.
///
/// Absolute paths are returned unchanged. If `graph_path` has no parent (e.g. `/`), the current
/// directory `.` is used as the base.
fn resolve_path_relative_to_graph(graph_path: &Path, path_str: &str) -> PathBuf {
    let requested = Path::new(path_str);
    if !requested.is_absolute() {
        let graph_dir = graph_path.parent().unwrap_or(Path::new("."));
        return graph_dir.join(requested);
    }
    requested.to_path_buf()
}
361
362fn discover_sidecar_manifest(graph_path: &Path) -> Option<PathBuf> {
363 let stem = graph_path
364 .file_stem()
365 .and_then(|s| s.to_str())
366 .unwrap_or_default();
367 [
368 graph_path.with_file_name(format!("{stem}.manifest.json")),
369 graph_path.with_file_name(DEFAULT_PATH_MANIFEST),
370 ]
371 .into_iter()
372 .find(|p| p.exists())
373}
374
375fn discover_weights_file(graph_path: &Path) -> Option<PathBuf> {
378 let stem = graph_path
379 .file_stem()
380 .and_then(|s| s.to_str())
381 .unwrap_or_default();
382 [
383 graph_path.with_file_name(format!("{stem}.safetensors")),
384 graph_path.with_file_name(format!("{stem}.weights")),
385 graph_path.with_file_name(DEFAULT_PATH_SAFETENSORS),
386 graph_path.with_file_name(DEFAULT_PATH_WEIGHTS),
387 ]
388 .into_iter()
389 .find(|p| p.exists())
390}
391
/// Returns `true` if `path` has a `.safetensors` (or `.safetensor`) extension, compared
/// ASCII-case-insensitively. Paths without an extension, or with a non-UTF-8 one, never match.
fn path_looks_like_safetensors(path: &Path) -> bool {
    let Some(ext) = path.extension().and_then(|ext| ext.to_str()) else {
        return false;
    };
    ["safetensors", "safetensor"]
        .iter()
        .any(|candidate| ext.eq_ignore_ascii_case(candidate))
}
398
399pub fn resolve_external_weights(
424 graph_json: &mut GraphJson,
425 graph_path: &Path,
426 weights_path: Option<&str>,
427 manifest_path: Option<&str>,
428) -> Result<(), WeightResolveError> {
429 eprintln!(
430 "[webnn graph] resolve external weights: graph={}, weights_path={}, manifest_path={}",
431 graph_path.display(),
432 weights_path.unwrap_or("<discover next to graph>"),
433 manifest_path.unwrap_or("<discover next to graph>"),
434 );
435
436 if !graph_has_external_weight_refs(graph_json) {
437 return Ok(());
438 }
439
440 let stem = graph_path
441 .file_stem()
442 .and_then(|s| s.to_str())
443 .unwrap_or_default();
444
445 let wp = if let Some(s) = weights_path {
446 let p = resolve_path_relative_to_graph(graph_path, s);
447 if !p.exists() {
448 return Err(WeightResolveError::Missing(format!(
449 "weights path `{}` does not exist",
450 p.display()
451 )));
452 }
453 p
454 } else {
455 discover_weights_file(graph_path).ok_or_else(|| {
456 WeightResolveError::Missing(format!(
457 "no weights file found next to `{0}`; expected `{1}.safetensors`, `{1}.weights`, \
458 `{DEFAULT_PATH_SAFETENSORS}`, or `{DEFAULT_PATH_WEIGHTS}`, or pass `weights_path`",
459 graph_path.display(),
460 stem,
461 ))
462 })?
463 };
464
465 if path_looks_like_safetensors(&wp) {
466 return inline_weights_from_safetensors(graph_json, &wp);
467 }
468
469 let mp = if let Some(s) = manifest_path {
470 let p = resolve_path_relative_to_graph(graph_path, s);
471 if !p.exists() {
472 return Err(WeightResolveError::Missing(format!(
473 "manifest path `{}` does not exist",
474 p.display()
475 )));
476 }
477 p
478 } else {
479 discover_sidecar_manifest(graph_path).ok_or_else(|| {
480 WeightResolveError::Missing(format!(
481 "weights blob `{0}` requires a manifest; pass `manifest_path` or place `{1}.manifest.json` / \
482 `{DEFAULT_PATH_MANIFEST}` next to `{2}`",
483 wp.display(),
484 stem,
485 graph_path.display()
486 ))
487 })?
488 };
489
490 inline_weights_from_manifest(graph_json, &mp, &wp)
491}
492
#[cfg(test)]
mod tests {
    use super::*;
    use safetensors::tensor::TensorView;
    use safetensors::{serialize, Dtype};
    use tempfile::TempDir;

    /// Graph with one float32 constant `weight` (shape `[2]`) backed by weight ref `weight`.
    const GRAPH_JSON: &str = r#"{
        "format": "webnn-graph-json",
        "version": 1,
        "inputs": { "x": { "dataType": "float32", "shape": [2] } },
        "consts": {
            "weight": {
                "dataType": "float32",
                "shape": [2],
                "init": { "kind": "weights", "ref": "weight" }
            }
        },
        "nodes": [],
        "outputs": { "y": "x" }
    }"#;

    /// Manifest describing tensor `weight` as the first 8 bytes of the weights blob.
    const MANIFEST_JSON: &str = r#"{
        "format": "webnn-weights-manifest",
        "version": 1,
        "endianness": "little",
        "tensors": {
            "weight": {
                "dataType": "float32",
                "shape": [2],
                "byteOffset": 0,
                "byteLength": 8
            }
        }
    }"#;

    /// Little-endian float32 bytes for `[1.0, 2.0]`.
    const ONE_TWO_F32: [u8; 8] = [0x00, 0x00, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x40];

    /// Writes a single-tensor safetensors archive to `path`.
    fn write_safetensors(
        path: &Path,
        dtype: Dtype,
        tensor_name: &str,
        shape: Vec<usize>,
        data: &[u8],
    ) {
        let view = TensorView::new(dtype, shape, data).unwrap();
        let bytes = serialize(vec![(tensor_name.to_string(), view)], None).unwrap();
        std::fs::write(path, bytes).unwrap();
    }

    /// Creates a temp dir containing `model.json` with `graph_content`.
    ///
    /// The returned `TempDir` must stay alive for the duration of the test.
    fn temp_graph(graph_content: &str) -> (TempDir, PathBuf) {
        let temp_dir = TempDir::new().unwrap();
        let graph_path = temp_dir.path().join(DEFAULT_PATH_JSON);
        std::fs::write(&graph_path, graph_content).unwrap();
        (temp_dir, graph_path)
    }

    fn parse_graph(graph_content: &str) -> GraphJson {
        serde_json::from_str(graph_content).unwrap()
    }

    /// Returns the inline bytes of constant `const_name`, panicking if it was not inlined.
    fn inline_bytes<'g>(graph: &'g GraphJson, const_name: &str) -> &'g [u8] {
        match &graph.consts[const_name].init {
            ConstInit::InlineBytes { bytes } => bytes.as_slice(),
            other => panic!("expected inline bytes, got {other:?}"),
        }
    }

    #[test]
    fn manifest_and_weights_inline() {
        let (temp_dir, graph_path) = temp_graph(GRAPH_JSON);
        std::fs::write(temp_dir.path().join("model.manifest.json"), MANIFEST_JSON).unwrap();
        std::fs::write(temp_dir.path().join(DEFAULT_PATH_WEIGHTS), ONE_TWO_F32).unwrap();

        let mut graph = parse_graph(GRAPH_JSON);
        resolve_external_weights(&mut graph, &graph_path, None, None).unwrap();
        assert_eq!(inline_bytes(&graph, "weight"), ONE_TWO_F32.as_slice());
    }

    #[test]
    fn explicit_manifest_and_weights_paths() {
        let (temp_dir, graph_path) = temp_graph(GRAPH_JSON);
        std::fs::write(temp_dir.path().join("custom.manifest.json"), MANIFEST_JSON).unwrap();
        std::fs::write(temp_dir.path().join("blob.weights"), ONE_TWO_F32).unwrap();

        let mut graph = parse_graph(GRAPH_JSON);
        resolve_external_weights(
            &mut graph,
            &graph_path,
            Some("blob.weights"),
            Some("custom.manifest.json"),
        )
        .unwrap();
        assert_eq!(inline_bytes(&graph, "weight"), ONE_TWO_F32.as_slice());
    }

    #[test]
    fn manifest_resolves_sanitized_tensor_names() {
        let graph_content =
            GRAPH_JSON.replace(r#""ref": "weight""#, r#""ref": "layer_0__w""#);
        let manifest = MANIFEST_JSON.replace(r#""weight": {"#, r#""layer.0::w": {"#);
        let (temp_dir, graph_path) = temp_graph(&graph_content);
        std::fs::write(temp_dir.path().join(DEFAULT_PATH_MANIFEST), manifest).unwrap();
        std::fs::write(temp_dir.path().join(DEFAULT_PATH_WEIGHTS), ONE_TWO_F32).unwrap();

        let mut graph = parse_graph(&graph_content);
        resolve_external_weights(&mut graph, &graph_path, None, None).unwrap();
        assert_eq!(inline_bytes(&graph, "weight"), ONE_TWO_F32.as_slice());
    }

    #[test]
    fn explicit_safetensors_weights_path() {
        let (temp_dir, graph_path) = temp_graph(GRAPH_JSON);
        write_safetensors(
            &temp_dir.path().join("custom.safetensors"),
            Dtype::F32,
            "weight",
            vec![2],
            &ONE_TWO_F32,
        );

        let mut graph = parse_graph(GRAPH_JSON);
        resolve_external_weights(&mut graph, &graph_path, Some("custom.safetensors"), None)
            .unwrap();
        assert_eq!(inline_bytes(&graph, "weight"), ONE_TWO_F32.as_slice());
    }

    #[test]
    fn manifest_arg_ignored_when_weights_path_is_safetensors() {
        let (temp_dir, graph_path) = temp_graph(GRAPH_JSON);
        write_safetensors(
            &temp_dir.path().join("weights.safetensors"),
            Dtype::F32,
            "weight",
            vec![2],
            &ONE_TWO_F32,
        );

        let mut graph = parse_graph(GRAPH_JSON);
        resolve_external_weights(
            &mut graph,
            &graph_path,
            Some("weights.safetensors"),
            Some("this_manifest_is_not_read.json"),
        )
        .unwrap();
        assert_eq!(inline_bytes(&graph, "weight"), ONE_TWO_F32.as_slice());
    }

    #[test]
    fn safetensors_inline() {
        let (temp_dir, graph_path) = temp_graph(GRAPH_JSON);
        write_safetensors(
            &temp_dir.path().join(DEFAULT_PATH_SAFETENSORS),
            Dtype::F32,
            "weight",
            vec![2],
            &ONE_TWO_F32,
        );

        let mut graph = parse_graph(GRAPH_JSON);
        resolve_external_weights(&mut graph, &graph_path, None, None).unwrap();
        assert_eq!(inline_bytes(&graph, "weight"), ONE_TWO_F32.as_slice());
    }

    #[test]
    fn safetensors_resolves_sanitized_tensor_names() {
        let graph_content =
            GRAPH_JSON.replace(r#""ref": "weight""#, r#""ref": "layer_0__w""#);
        let (temp_dir, graph_path) = temp_graph(&graph_content);
        write_safetensors(
            &temp_dir.path().join(DEFAULT_PATH_SAFETENSORS),
            Dtype::F32,
            "layer.0::w",
            vec![2],
            &ONE_TWO_F32,
        );

        let mut graph = parse_graph(&graph_content);
        resolve_external_weights(&mut graph, &graph_path, None, None).unwrap();
        assert_eq!(inline_bytes(&graph, "weight"), ONE_TWO_F32.as_slice());
    }

    #[test]
    fn safetensors_shape_mismatch_errors() {
        let (temp_dir, graph_path) = temp_graph(GRAPH_JSON);
        write_safetensors(
            &temp_dir.path().join(DEFAULT_PATH_SAFETENSORS),
            Dtype::F32,
            "weight",
            vec![1],
            &ONE_TWO_F32[..4],
        );

        let mut graph = parse_graph(GRAPH_JSON);
        let err = resolve_external_weights(&mut graph, &graph_path, None, None).unwrap_err();
        assert!(matches!(err, WeightResolveError::Safetensors(_)));
    }

    #[test]
    fn safetensors_dtype_mismatch_errors() {
        let (temp_dir, graph_path) = temp_graph(GRAPH_JSON);
        write_safetensors(
            &temp_dir.path().join(DEFAULT_PATH_SAFETENSORS),
            Dtype::I32,
            "weight",
            vec![2],
            &ONE_TWO_F32,
        );

        let mut graph = parse_graph(GRAPH_JSON);
        let err = resolve_external_weights(&mut graph, &graph_path, None, None).unwrap_err();
        assert!(matches!(err, WeightResolveError::Safetensors(_)));
    }

    #[test]
    fn out_of_bounds_manifest_errors() {
        let manifest = MANIFEST_JSON.replace("\"byteLength\": 8", "\"byteLength\": 100");
        let (temp_dir, graph_path) = temp_graph(GRAPH_JSON);
        std::fs::write(temp_dir.path().join(DEFAULT_PATH_MANIFEST), manifest).unwrap();
        std::fs::write(temp_dir.path().join(DEFAULT_PATH_WEIGHTS), [0u8; 8]).unwrap();

        let mut graph = parse_graph(GRAPH_JSON);
        let err = resolve_external_weights(&mut graph, &graph_path, None, None).unwrap_err();
        assert!(matches!(err, WeightResolveError::Manifest(_)));
    }

    #[test]
    fn missing_weights_file_errors() {
        let (_temp_dir, graph_path) = temp_graph(GRAPH_JSON);

        let mut graph = parse_graph(GRAPH_JSON);
        let err = resolve_external_weights(&mut graph, &graph_path, None, None).unwrap_err();
        assert!(matches!(err, WeightResolveError::Missing(_)));
    }

    #[test]
    fn explicit_missing_weights_path_errors() {
        let (_temp_dir, graph_path) = temp_graph(GRAPH_JSON);

        let mut graph = parse_graph(GRAPH_JSON);
        let err = resolve_external_weights(&mut graph, &graph_path, Some("nope.weights"), None)
            .unwrap_err();
        assert!(matches!(err, WeightResolveError::Missing(_)));
    }

    #[test]
    fn safetensors_preferred_over_invalid_manifest() {
        let (temp_dir, graph_path) = temp_graph(GRAPH_JSON);
        std::fs::write(
            temp_dir.path().join(DEFAULT_PATH_MANIFEST),
            "{ not valid manifest json",
        )
        .unwrap();
        std::fs::write(temp_dir.path().join(DEFAULT_PATH_WEIGHTS), [0u8; 8]).unwrap();
        write_safetensors(
            &temp_dir.path().join(DEFAULT_PATH_SAFETENSORS),
            Dtype::F32,
            "weight",
            vec![2],
            &ONE_TWO_F32,
        );

        let mut graph = parse_graph(GRAPH_JSON);
        resolve_external_weights(&mut graph, &graph_path, None, None).unwrap();
        assert_eq!(inline_bytes(&graph, "weight"), ONE_TWO_F32.as_slice());
    }

    #[test]
    fn safetensors_bf16_converts_to_float32_for_graph_constants() {
        use half::bf16;

        let (temp_dir, graph_path) = temp_graph(GRAPH_JSON);
        let bf16_bytes: Vec<u8> = [1.0f32, 2.0f32]
            .iter()
            .flat_map(|&f| bf16::from_f32(f).to_bits().to_le_bytes())
            .collect();
        write_safetensors(
            &temp_dir.path().join(DEFAULT_PATH_SAFETENSORS),
            Dtype::BF16,
            "weight",
            vec![2],
            &bf16_bytes,
        );

        let mut graph = parse_graph(GRAPH_JSON);
        resolve_external_weights(&mut graph, &graph_path, None, None).unwrap();
        // 1.0 and 2.0 are exact in BF16, so the widened bytes equal the f32 encoding.
        assert_eq!(inline_bytes(&graph, "weight"), ONE_TWO_F32.as_slice());
    }
}