Skip to main content

webnn_graph/
external_weights.rs

1//! Resolve `@weights` / [`ConstInit::Weights`] using sidecar files
2//! next to a graph path (SafeTensors or manifest + raw weights blob).
3
4use std::collections::HashMap;
5use std::fs;
6use std::path::{Path, PathBuf};
7
8use half::bf16;
9use safetensors::tensor::Dtype as StDtype;
10use safetensors::SafeTensors;
11use serde::Deserialize;
12use thiserror::Error;
13
14use crate::ast::{ConstInit, DataType as AstDataType, GraphJson};
15
/// Default graph JSON basename (typical sidecar layout next to weights / manifest).
pub const DEFAULT_PATH_JSON: &str = "model.json";
/// Default raw weights blob basename when not using a stem-prefixed `*.weights` file.
///
/// Weights discovery probes this name last, after the stem-prefixed candidates and
/// [`DEFAULT_PATH_SAFETENSORS`].
pub const DEFAULT_PATH_WEIGHTS: &str = "model.weights";
/// Default SafeTensors archive basename when not using a stem-prefixed `*.safetensors` file.
///
/// Weights discovery probes this name before [`DEFAULT_PATH_WEIGHTS`].
pub const DEFAULT_PATH_SAFETENSORS: &str = "model.safetensors";
/// Default weights manifest basename when not using a stem-prefixed `*.manifest.json` file.
///
/// Only consulted when the selected weights file is a raw blob rather than SafeTensors.
pub const DEFAULT_PATH_MANIFEST: &str = "manifest.json";
24
/// Failure while resolving external weights for a [`GraphJson`].
///
/// The string-carrying variants prefix their message with a bracketed tag
/// (`[safetensors]`, `[manifest-weights]`, `[weights]`) identifying the code path that failed.
#[derive(Debug, Error)]
pub enum WeightResolveError {
    /// Could not read a required file from disk.
    ///
    /// `path` is the file that was being read; `source` is the underlying I/O error.
    #[error("failed to read `{path}`: {source}")]
    ReadFile {
        path: PathBuf,
        #[source]
        source: std::io::Error,
    },
    /// Manifest JSON is invalid.
    ///
    /// `path` is the manifest file; `source` is the `serde_json` parse error.
    #[error("failed to parse manifest JSON at `{path}`: {source}")]
    ManifestJson {
        path: PathBuf,
        #[source]
        source: serde_json::Error,
    },
    /// SafeTensors–specific validation or parse error (archive parse failure, missing or
    /// ambiguous tensor name, shape / dtype mismatch, malformed BF16 payload).
    #[error("[safetensors] {0}")]
    Safetensors(String),
    /// Manifest + weights blob resolution error (missing tensor entry or a byte range that
    /// does not fit inside the weights file).
    #[error("[manifest-weights] {0}")]
    Manifest(String),
    /// No usable weight source was found next to the graph, or an explicitly passed
    /// weights / manifest path does not exist.
    #[error("[weights] {0}")]
    Missing(String),
}
52
53fn graph_has_external_weight_refs(graph_json: &GraphJson) -> bool {
54    graph_json
55        .consts
56        .values()
57        .any(|c| matches!(c.init, ConstInit::Weights { .. }))
58}
59
/// Produces the sanitized lookup form of a tensor / manifest key: every `::` becomes `__`
/// and, afterwards, every `.` becomes `_`.
#[inline]
fn sanitize_weight_key(name: &str) -> String {
    // Path separators first, so a lone `:` left over from an odd run (e.g. `:::`) is kept.
    let without_path_separators = name.split("::").collect::<Vec<_>>().join("__");
    without_path_separators
        .chars()
        .map(|ch| if ch == '.' { '_' } else { ch })
        .collect()
}
65
66fn safetensors_st_dtype_matches_ast(st: StDtype, ast: &AstDataType) -> bool {
67    matches!(
68        (ast, st),
69        (AstDataType::Float32, StDtype::F32)
70            | (AstDataType::Float16, StDtype::F16)
71            | (AstDataType::Int32, StDtype::I32)
72            | (AstDataType::Uint32, StDtype::U32)
73            | (AstDataType::Int64, StDtype::I64)
74            | (AstDataType::Uint64, StDtype::U64)
75            | (AstDataType::Int8, StDtype::I8)
76            | (AstDataType::Uint8, StDtype::U8)
77    )
78}
79
/// Returns `true` when the SafeTensors shape and the graph constant shape have the same rank
/// and identical dimensions.
///
/// Dimensions are compared with a checked `usize -> u32` conversion: a SafeTensors dimension
/// that does not fit in `u32` can never equal a graph dimension, so it is a mismatch rather
/// than being silently truncated by an `as` cast.
fn st_shape_matches_const(st_shape: &[usize], const_shape: &[u32]) -> bool {
    st_shape.len() == const_shape.len()
        && st_shape
            .iter()
            .zip(const_shape)
            .all(|(&st_dim, &const_dim)| u32::try_from(st_dim).is_ok_and(|d| d == const_dim))
}
89
90/// Convert little-endian BF16 payload to little-endian F32 (WebNN float32 constants).
91fn bf16_bytes_to_f32_le_bytes(data: &[u8]) -> Result<Vec<u8>, WeightResolveError> {
92    if !data.len().is_multiple_of(2) {
93        return Err(WeightResolveError::Safetensors(format!(
94            "BF16 data length {} is not a multiple of 2",
95            data.len()
96        )));
97    }
98    let mut out = Vec::with_capacity(data.len() * 2);
99    for chunk in data.chunks_exact(2) {
100        let bits = u16::from_le_bytes([chunk[0], chunk[1]]);
101        let v = bf16::from_bits(bits).to_f32();
102        out.extend_from_slice(&v.to_le_bytes());
103    }
104    Ok(out)
105}
106
107fn safetensors_sanitized_name_map(
108    st: &SafeTensors<'_>,
109) -> Result<HashMap<String, String>, WeightResolveError> {
110    let mut out: HashMap<String, String> = HashMap::new();
111    for name in st.names() {
112        let sanitized = sanitize_weight_key(name);
113        if let Some(prev) = out.insert(sanitized.clone(), name.to_string()) {
114            if prev.as_str() != name {
115                return Err(WeightResolveError::Safetensors(format!(
116                    "ambiguous sanitized tensor name `{sanitized}` (both `{prev}` and `{name}`)"
117                )));
118            }
119        }
120    }
121    Ok(out)
122}
123
124fn resolve_tensor_view<'a>(
125    st: &'a SafeTensors<'a>,
126    sanitized_map: &HashMap<String, String>,
127    r#ref: &str,
128) -> Result<safetensors::tensor::TensorView<'a>, WeightResolveError> {
129    if let Ok(v) = st.tensor(r#ref) {
130        return Ok(v);
131    }
132    let orig = sanitized_map.get(r#ref).ok_or_else(|| {
133        WeightResolveError::Safetensors(format!("tensor `{ref}` not found in safetensors archive"))
134    })?;
135    st.tensor(orig.as_str())
136        .map_err(|e| WeightResolveError::Safetensors(format!("tensor `{ref}` (via `{orig}`): {e}")))
137}
138
139fn inline_weights_from_safetensors(
140    graph_json: &mut GraphJson,
141    safetensors_path: &Path,
142) -> Result<(), WeightResolveError> {
143    let weight_ref_count = graph_json
144        .consts
145        .values()
146        .filter(|c| matches!(c.init, ConstInit::Weights { .. }))
147        .count();
148    eprintln!(
149        "[webnn-graph] resolve safetensors: path=`{}` weight_ref_count={}",
150        safetensors_path.display(),
151        weight_ref_count
152    );
153
154    let bytes = fs::read(safetensors_path).map_err(|source| WeightResolveError::ReadFile {
155        path: safetensors_path.to_path_buf(),
156        source,
157    })?;
158    let st = SafeTensors::deserialize(&bytes).map_err(|e| {
159        WeightResolveError::Safetensors(format!("`{}`: {e}", safetensors_path.display()))
160    })?;
161    let sanitized_map = safetensors_sanitized_name_map(&st)?;
162
163    for (const_name, const_decl) in graph_json.consts.iter_mut() {
164        let ConstInit::Weights { r#ref } = &const_decl.init else {
165            continue;
166        };
167        let view = match resolve_tensor_view(&st, &sanitized_map, r#ref) {
168            Ok(v) => v,
169            Err(e) => {
170                eprintln!(
171                    "[webnn-graph] warning: safetensors could not resolve weight ref `{ref}` \
172                     (constant `{const_name}`) from `{}`: {e}",
173                    safetensors_path.display()
174                );
175                return Err(e);
176            }
177        };
178        if !st_shape_matches_const(view.shape(), &const_decl.shape) {
179            let msg = format!(
180                "shape mismatch for weight `{ref}` (constant `{const_name}`): graph {:?} vs safetensors {:?}",
181                const_decl.shape,
182                view.shape()
183            );
184            eprintln!(
185                "[webnn-graph] warning: safetensors could not resolve weight `{ref}` \
186                 (constant `{const_name}`) from `{}`: {msg}",
187                safetensors_path.display()
188            );
189            return Err(WeightResolveError::Safetensors(msg));
190        }
191
192        let st_dtype = view.dtype();
193        let raw = view.data();
194        let bytes = if safetensors_st_dtype_matches_ast(st_dtype, &const_decl.data_type) {
195            raw.to_vec()
196        } else if matches!(
197            (&const_decl.data_type, st_dtype),
198            (AstDataType::Float32, StDtype::BF16)
199        ) {
200            let elem_count: usize = const_decl.shape.iter().map(|&x| x as usize).product();
201            let expected = elem_count.checked_mul(2).ok_or_else(|| {
202                WeightResolveError::Safetensors(format!(
203                    "element count overflow for weight `{ref}` (constant `{const_name}`)"
204                ))
205            })?;
206            if raw.len() != expected {
207                return Err(WeightResolveError::Safetensors(format!(
208                    "BF16 tensor `{ref}` (constant `{const_name}`): byte length {} != expected {} ({} BF16 elements)",
209                    raw.len(),
210                    expected,
211                    elem_count
212                )));
213            }
214            eprintln!(
215                "[webnn-graph] safetensors: converting BF16 → float32 for weight `{ref}` (constant `{const_name}`)"
216            );
217            bf16_bytes_to_f32_le_bytes(raw)?
218        } else {
219            let msg = format!(
220                "dtype mismatch for weight `{ref}` (constant `{const_name}`): graph declares {:?} but safetensors has {:?}",
221                const_decl.data_type,
222                st_dtype
223            );
224            eprintln!(
225                "[webnn-graph] warning: safetensors could not resolve weight `{ref}` \
226                 (constant `{const_name}`) from `{}`: {msg}",
227                safetensors_path.display()
228            );
229            return Err(WeightResolveError::Safetensors(msg));
230        };
231
232        const_decl.init = ConstInit::InlineBytes { bytes };
233    }
234
235    let still_count = graph_json
236        .consts
237        .values()
238        .filter(|c| matches!(c.init, ConstInit::Weights { .. }))
239        .count();
240    if still_count > 0 {
241        return Err(WeightResolveError::Safetensors(format!(
242            "safetensors `{}` did not provide all tensors referenced by the graph ({still_count} still missing)",
243            safetensors_path.display()
244        )));
245    }
246
247    Ok(())
248}
249
250/// Weight manifest JSON next to a graph (supports `webnn-weights-manifest` and related layouts).
251#[derive(Debug, Deserialize)]
252struct FlexibleManifest {
253    #[serde(default)]
254    tensors: HashMap<String, FlexibleTensorEntry>,
255}
256
257#[derive(Debug, Deserialize, Clone)]
258struct FlexibleTensorEntry {
259    #[serde(rename = "byteOffset")]
260    byte_offset: u64,
261    #[serde(rename = "byteLength")]
262    byte_length: u64,
263}
264
265fn inline_weights_from_manifest(
266    graph_json: &mut GraphJson,
267    manifest_path: &Path,
268    weights_path: &Path,
269) -> Result<(), WeightResolveError> {
270    let manifest_text =
271        fs::read_to_string(manifest_path).map_err(|source| WeightResolveError::ReadFile {
272            path: manifest_path.to_path_buf(),
273            source,
274        })?;
275    let weights_bytes = fs::read(weights_path).map_err(|source| WeightResolveError::ReadFile {
276        path: weights_path.to_path_buf(),
277        source,
278    })?;
279
280    let manifest: FlexibleManifest = serde_json::from_str(&manifest_text).map_err(|source| {
281        WeightResolveError::ManifestJson {
282            path: manifest_path.to_path_buf(),
283            source,
284        }
285    })?;
286
287    let mut manifest_by_sanitized: HashMap<String, Vec<FlexibleTensorEntry>> = HashMap::new();
288    for (name, entry) in &manifest.tensors {
289        let sanitized = sanitize_weight_key(name);
290        manifest_by_sanitized
291            .entry(sanitized)
292            .or_default()
293            .push(entry.clone());
294    }
295
296    for (const_name, const_decl) in graph_json.consts.iter_mut() {
297        let ConstInit::Weights { r#ref } = &const_decl.init else {
298            continue;
299        };
300        let entry = manifest
301            .tensors
302            .get(r#ref)
303            .cloned()
304            .or_else(|| {
305                manifest_by_sanitized.get(r#ref).and_then(|entries| {
306                    if entries.len() == 1 {
307                        Some(entries[0].clone())
308                    } else {
309                        None
310                    }
311                })
312            })
313            .ok_or_else(|| {
314                WeightResolveError::Manifest(format!(
315                    "no manifest tensor entry for weight ref `{ref}` (constant `{const_name}`)"
316                ))
317            })?;
318
319        let start = usize::try_from(entry.byte_offset).map_err(|_| {
320            WeightResolveError::Manifest(format!(
321                "byteOffset {} for `{ref}` does not fit in usize",
322                entry.byte_offset
323            ))
324        })?;
325        let len = usize::try_from(entry.byte_length).map_err(|_| {
326            WeightResolveError::Manifest(format!(
327                "byteLength {} for `{ref}` does not fit in usize",
328                entry.byte_length
329            ))
330        })?;
331        let end = start.checked_add(len).ok_or_else(|| {
332            WeightResolveError::Manifest(format!("byte range overflow for `{ref}`"))
333        })?;
334        if end > weights_bytes.len() {
335            return Err(WeightResolveError::Manifest(format!(
336                "byte range [{start}, {end}) for `{ref}` exceeds weights file length {} (`{}`)",
337                weights_bytes.len(),
338                weights_path.display()
339            )));
340        }
341        const_decl.init = ConstInit::InlineBytes {
342            bytes: weights_bytes[start..end].to_vec(),
343        };
344    }
345    Ok(())
346}
347
/// Resolves `path_str` against the directory containing `graph_path`; an absolute
/// `path_str` is returned unchanged.
///
/// When `graph_path` has no parent, `.` is used as the base directory.
fn resolve_path_relative_to_graph(graph_path: &Path, path_str: &str) -> PathBuf {
    let graph_dir = match graph_path.parent() {
        Some(dir) => dir,
        None => Path::new("."),
    };
    // `Path::join` discards the base when its argument is absolute, which covers the
    // "absolute path is taken as-is" case without a separate branch.
    graph_dir.join(path_str)
}
361
362fn discover_sidecar_manifest(graph_path: &Path) -> Option<PathBuf> {
363    let stem = graph_path
364        .file_stem()
365        .and_then(|s| s.to_str())
366        .unwrap_or_default();
367    [
368        graph_path.with_file_name(format!("{stem}.manifest.json")),
369        graph_path.with_file_name(DEFAULT_PATH_MANIFEST),
370    ]
371    .into_iter()
372    .find(|p| p.exists())
373}
374
375/// Discovers a single weights file next to `graph_path`, in order: `{stem}.safetensors`,
376/// `{stem}.weights`, [`DEFAULT_PATH_SAFETENSORS`], [`DEFAULT_PATH_WEIGHTS`].
377fn discover_weights_file(graph_path: &Path) -> Option<PathBuf> {
378    let stem = graph_path
379        .file_stem()
380        .and_then(|s| s.to_str())
381        .unwrap_or_default();
382    [
383        graph_path.with_file_name(format!("{stem}.safetensors")),
384        graph_path.with_file_name(format!("{stem}.weights")),
385        graph_path.with_file_name(DEFAULT_PATH_SAFETENSORS),
386        graph_path.with_file_name(DEFAULT_PATH_WEIGHTS),
387    ]
388    .into_iter()
389    .find(|p| p.exists())
390}
391
/// Whether `path` refers to a SafeTensors archive, judged only by its extension
/// (`safetensors` or `safetensor`, ASCII case-insensitive).
fn path_looks_like_safetensors(path: &Path) -> bool {
    match path.extension().and_then(|ext| ext.to_str()) {
        Some(ext) => ["safetensors", "safetensor"]
            .iter()
            .any(|known| ext.eq_ignore_ascii_case(known)),
        // No extension, or one that is not valid UTF-8.
        None => false,
    }
}
398
399/// If `graph_json` contains any `ConstInit::Weights` references, load tensors from disk next to
400/// `graph_path` and replace them with [`ConstInit::InlineBytes`].
401///
402/// ## Resolution
403///
404/// 1. **No-op.** If the graph has no [`ConstInit::Weights`] initializers, return `Ok(())` without
405///    reading the filesystem.
406///
407/// 2. **Resolve weights path** (discovery is separate from loading):
408///    - If `weights_path` is set: resolve relative to the graph’s directory (or absolute as-is); the file
409///      must exist or return [`WeightResolveError::Missing`].
410///    - Else: [`discover_weights_file`] searches next to the graph in order: `{stem}.safetensors`,
411///      `{stem}.weights`, [`DEFAULT_PATH_SAFETENSORS`], [`DEFAULT_PATH_WEIGHTS`]. If none exist, return
412///      [`WeightResolveError::Missing`].
413///
414/// 3. **Load by kind:**
415///    - If the weights path is SafeTensors → [`inline_weights_from_safetensors`] and return (any
416///      `manifest_path` is ignored).
417///    - Otherwise it is a binary blob → resolve manifest: explicit `manifest_path` must exist, or
418///      [`discover_sidecar_manifest`] must find `{stem}.manifest.json` / [`DEFAULT_PATH_MANIFEST`], else
419///      [`WeightResolveError::Missing`]. Then [`inline_weights_from_manifest`].
420///
421/// Incomplete SafeTensors resolution returns [`WeightResolveError::Safetensors`]; manifest errors use
422/// [`WeightResolveError::Manifest`] / [`WeightResolveError::ManifestJson`].
423pub fn resolve_external_weights(
424    graph_json: &mut GraphJson,
425    graph_path: &Path,
426    weights_path: Option<&str>,
427    manifest_path: Option<&str>,
428) -> Result<(), WeightResolveError> {
429    eprintln!(
430        "[webnn graph] resolve external weights: graph={}, weights_path={}, manifest_path={}",
431        graph_path.display(),
432        weights_path.unwrap_or("<discover next to graph>"),
433        manifest_path.unwrap_or("<discover next to graph>"),
434    );
435
436    if !graph_has_external_weight_refs(graph_json) {
437        return Ok(());
438    }
439
440    let stem = graph_path
441        .file_stem()
442        .and_then(|s| s.to_str())
443        .unwrap_or_default();
444
445    let wp = if let Some(s) = weights_path {
446        let p = resolve_path_relative_to_graph(graph_path, s);
447        if !p.exists() {
448            return Err(WeightResolveError::Missing(format!(
449                "weights path `{}` does not exist",
450                p.display()
451            )));
452        }
453        p
454    } else {
455        discover_weights_file(graph_path).ok_or_else(|| {
456            WeightResolveError::Missing(format!(
457                "no weights file found next to `{0}`; expected `{1}.safetensors`, `{1}.weights`, \
458                 `{DEFAULT_PATH_SAFETENSORS}`, or `{DEFAULT_PATH_WEIGHTS}`, or pass `weights_path`",
459                graph_path.display(),
460                stem,
461            ))
462        })?
463    };
464
465    if path_looks_like_safetensors(&wp) {
466        return inline_weights_from_safetensors(graph_json, &wp);
467    }
468
469    let mp = if let Some(s) = manifest_path {
470        let p = resolve_path_relative_to_graph(graph_path, s);
471        if !p.exists() {
472            return Err(WeightResolveError::Missing(format!(
473                "manifest path `{}` does not exist",
474                p.display()
475            )));
476        }
477        p
478    } else {
479        discover_sidecar_manifest(graph_path).ok_or_else(|| {
480            WeightResolveError::Missing(format!(
481                "weights blob `{0}` requires a manifest; pass `manifest_path` or place `{1}.manifest.json` / \
482                 `{DEFAULT_PATH_MANIFEST}` next to `{2}`",
483                wp.display(),
484                stem,
485                graph_path.display()
486            ))
487        })?
488    };
489
490    inline_weights_from_manifest(graph_json, &mp, &wp)
491}
492
#[cfg(test)]
mod tests {
    use super::*;
    use safetensors::tensor::TensorView;
    use safetensors::{serialize, Dtype};
    use tempfile::TempDir;

    /// Graph fixture shared by every test: a single float32 `[2]` constant named `weight`
    /// whose initializer is the external weight reference `weight`.
    const GRAPH_WITH_WEIGHT_REF: &str = r#"{
            "format": "webnn-graph-json",
            "version": 1,
            "inputs": { "x": { "dataType": "float32", "shape": [2] } },
            "consts": {
                "weight": {
                    "dataType": "float32",
                    "shape": [2],
                    "init": { "kind": "weights", "ref": "weight" }
                }
            },
            "nodes": [],
            "outputs": { "y": "x" }
        }"#;

    /// Manifest placing `weight` at bytes `[0, 8)` of the blob.
    const MANIFEST_IN_BOUNDS: &str = r#"{
            "format": "webnn-weights-manifest",
            "version": 1,
            "endianness": "little",
            "tensors": {
                "weight": {
                    "dataType": "float32",
                    "shape": [2],
                    "byteOffset": 0,
                    "byteLength": 8
                }
            }
        }"#;

    /// Manifest claiming 100 bytes for `weight`, more than the 8-byte test blob holds.
    const MANIFEST_OUT_OF_BOUNDS: &str = r#"{
            "format": "webnn-weights-manifest",
            "version": 1,
            "tensors": {
                "weight": {
                    "dataType": "float32",
                    "shape": [2],
                    "byteOffset": 0,
                    "byteLength": 100
                }
            }
        }"#;

    /// Little-endian float32 payload for `[1.0, 2.0]`.
    const F32_ONE_TWO_LE: [u8; 8] = [0x00, 0x00, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x40];

    /// Writes a SafeTensors archive holding exactly one tensor.
    fn write_safetensors(
        path: &Path,
        dtype: Dtype,
        tensor_name: &str,
        shape: Vec<usize>,
        data: &[u8],
    ) {
        let view = TensorView::new(dtype, shape, data).unwrap();
        let bytes = serialize(vec![(tensor_name.to_string(), view)], None).unwrap();
        std::fs::write(path, bytes).unwrap();
    }

    /// Creates a temp dir containing the graph JSON under [`DEFAULT_PATH_JSON`].
    ///
    /// The returned `TempDir` guard must be kept alive for as long as the files are needed.
    fn graph_in_temp_dir() -> (TempDir, PathBuf) {
        let temp_dir = TempDir::new().unwrap();
        let graph_path = temp_dir.path().join(DEFAULT_PATH_JSON);
        std::fs::write(&graph_path, GRAPH_WITH_WEIGHT_REF).unwrap();
        (temp_dir, graph_path)
    }

    /// Parses the shared graph fixture.
    fn parse_graph() -> GraphJson {
        serde_json::from_str(GRAPH_WITH_WEIGHT_REF).unwrap()
    }

    /// Returns the inlined bytes of constant `weight`, panicking if it was not inlined.
    fn inlined_weight_bytes(graph: &GraphJson) -> &[u8] {
        match &graph.consts["weight"].init {
            ConstInit::InlineBytes { bytes } => bytes.as_slice(),
            other => panic!("expected inline bytes, got {:?}", other),
        }
    }

    #[test]
    fn manifest_and_weights_inline() {
        let (temp_dir, graph_path) = graph_in_temp_dir();
        std::fs::write(
            temp_dir.path().join("model.manifest.json"),
            MANIFEST_IN_BOUNDS,
        )
        .unwrap();
        std::fs::write(temp_dir.path().join(DEFAULT_PATH_WEIGHTS), F32_ONE_TWO_LE).unwrap();

        let mut graph = parse_graph();
        resolve_external_weights(&mut graph, &graph_path, None, None).unwrap();
        assert_eq!(inlined_weight_bytes(&graph), &F32_ONE_TWO_LE[..]);
    }

    #[test]
    fn explicit_manifest_and_weights_paths() {
        let (temp_dir, graph_path) = graph_in_temp_dir();
        std::fs::write(
            temp_dir.path().join("custom.manifest.json"),
            MANIFEST_IN_BOUNDS,
        )
        .unwrap();
        std::fs::write(temp_dir.path().join("blob.weights"), F32_ONE_TWO_LE).unwrap();

        let mut graph = parse_graph();
        resolve_external_weights(
            &mut graph,
            &graph_path,
            Some("blob.weights"),
            Some("custom.manifest.json"),
        )
        .unwrap();
        assert_eq!(inlined_weight_bytes(&graph), &F32_ONE_TWO_LE[..]);
    }

    #[test]
    fn explicit_safetensors_weights_path() {
        let (temp_dir, graph_path) = graph_in_temp_dir();
        write_safetensors(
            &temp_dir.path().join("custom.safetensors"),
            Dtype::F32,
            "weight",
            vec![2],
            &F32_ONE_TWO_LE,
        );

        let mut graph = parse_graph();
        resolve_external_weights(&mut graph, &graph_path, Some("custom.safetensors"), None)
            .unwrap();
        assert_eq!(inlined_weight_bytes(&graph), &F32_ONE_TWO_LE[..]);
    }

    #[test]
    fn manifest_arg_ignored_when_weights_path_is_safetensors() {
        let (temp_dir, graph_path) = graph_in_temp_dir();
        write_safetensors(
            &temp_dir.path().join("weights.safetensors"),
            Dtype::F32,
            "weight",
            vec![2],
            &F32_ONE_TWO_LE,
        );

        let mut graph = parse_graph();
        // The manifest path does not exist; it must not be touched on the SafeTensors route.
        resolve_external_weights(
            &mut graph,
            &graph_path,
            Some("weights.safetensors"),
            Some("this_manifest_is_not_read.json"),
        )
        .unwrap();
        assert_eq!(inlined_weight_bytes(&graph), &F32_ONE_TWO_LE[..]);
    }

    #[test]
    fn safetensors_inline() {
        let (temp_dir, graph_path) = graph_in_temp_dir();
        write_safetensors(
            &temp_dir.path().join(DEFAULT_PATH_SAFETENSORS),
            Dtype::F32,
            "weight",
            vec![2],
            &F32_ONE_TWO_LE,
        );

        let mut graph = parse_graph();
        resolve_external_weights(&mut graph, &graph_path, None, None).unwrap();
        assert_eq!(inlined_weight_bytes(&graph), &F32_ONE_TWO_LE[..]);
    }

    #[test]
    fn out_of_bounds_manifest_errors() {
        let (temp_dir, graph_path) = graph_in_temp_dir();
        std::fs::write(
            temp_dir.path().join(DEFAULT_PATH_MANIFEST),
            MANIFEST_OUT_OF_BOUNDS,
        )
        .unwrap();
        std::fs::write(temp_dir.path().join(DEFAULT_PATH_WEIGHTS), [0u8; 8]).unwrap();

        let mut graph = parse_graph();
        let err = resolve_external_weights(&mut graph, &graph_path, None, None).unwrap_err();
        assert!(matches!(err, WeightResolveError::Manifest(_)));
    }

    #[test]
    fn safetensors_preferred_over_invalid_manifest() {
        let (temp_dir, graph_path) = graph_in_temp_dir();
        std::fs::write(
            temp_dir.path().join(DEFAULT_PATH_MANIFEST),
            "{ not valid manifest json",
        )
        .unwrap();
        std::fs::write(temp_dir.path().join(DEFAULT_PATH_WEIGHTS), [0u8; 8]).unwrap();
        write_safetensors(
            &temp_dir.path().join(DEFAULT_PATH_SAFETENSORS),
            Dtype::F32,
            "weight",
            vec![2],
            &F32_ONE_TWO_LE,
        );

        let mut graph = parse_graph();
        resolve_external_weights(&mut graph, &graph_path, None, None).unwrap();
        // The all-zero blob would yield zeros; these bytes prove the archive was used.
        assert_eq!(inlined_weight_bytes(&graph), &F32_ONE_TWO_LE[..]);
    }

    #[test]
    fn safetensors_bf16_converts_to_float32_for_graph_constants() {
        use half::bf16;

        let (temp_dir, graph_path) = graph_in_temp_dir();

        let mut bf16_bytes = Vec::new();
        bf16_bytes.extend_from_slice(&bf16::from_f32(1.0f32).to_bits().to_le_bytes());
        bf16_bytes.extend_from_slice(&bf16::from_f32(2.0f32).to_bits().to_le_bytes());
        write_safetensors(
            &temp_dir.path().join(DEFAULT_PATH_SAFETENSORS),
            Dtype::BF16,
            "weight",
            vec![2],
            &bf16_bytes,
        );

        let mut graph = parse_graph();
        resolve_external_weights(&mut graph, &graph_path, None, None).unwrap();

        let expected: Vec<u8> = [1.0f32, 2.0f32]
            .iter()
            .flat_map(|f| f.to_le_bytes())
            .collect();
        assert_eq!(inlined_weight_bytes(&graph), expected.as_slice());
    }
}