webnn_graph/
weights.rs

1use serde::{Deserialize, Serialize};
2use std::collections::BTreeMap;
3
4use crate::ast::DataType;
5
6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct WeightsManifest {
8    pub format: String, // "wg-weights-manifest"
9    pub version: u32,   // 1
10    pub endianness: String,
11    pub tensors: BTreeMap<String, TensorEntry>,
12}
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct TensorEntry {
16    #[serde(rename = "dataType")]
17    pub data_type: DataType,
18    pub shape: Vec<u32>,
19    #[serde(rename = "byteOffset")]
20    pub byte_offset: u64,
21    #[serde(rename = "byteLength")]
22    pub byte_length: u64,
23    #[serde(default)]
24    pub layout: Option<String>,
25}
26
27pub fn dtype_size(dt: &DataType) -> u64 {
28    match dt {
29        DataType::Float32 => 4,
30        DataType::Float16 => 2,
31        DataType::Int32 => 4,
32        DataType::Uint32 => 4,
33        DataType::Int64 => 8,
34        DataType::Uint64 => 8,
35        DataType::Int8 => 1,
36        DataType::Uint8 => 1,
37    }
38}
39
/// Returns the number of elements in a tensor with the given `shape`.
///
/// An empty shape (a scalar) has exactly one element. The product is
/// computed in `u64` using saturating multiplication: if it would overflow,
/// the running value clamps to `u64::MAX` instead of wrapping or panicking.
/// Because `u64::MAX.saturating_mul(0) == 0`, a zero-sized dimension still
/// yields `0` even after the product has saturated.
pub fn numel(shape: &[u32]) -> u64 {
    shape
        .iter()
        .map(|&dim| u64::from(dim))
        .fold(1, u64::saturating_mul)
}
45
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dtype_size() {
        let cases = [
            (DataType::Float32, 4),
            (DataType::Float16, 2),
            (DataType::Int32, 4),
            (DataType::Uint32, 4),
            (DataType::Int64, 8),
            (DataType::Uint64, 8),
            (DataType::Int8, 1),
            (DataType::Uint8, 1),
        ];
        for (dt, expected) in cases.iter() {
            assert_eq!(dtype_size(dt), *expected, "dtype_size({:?})", dt);
        }
    }

    #[test]
    fn test_numel() {
        // An empty shape is a scalar: exactly one element.
        assert_eq!(numel(&[]), 1);

        let cases: [(&[u32], u64); 5] = [
            (&[10], 10),
            (&[2, 3], 6),
            (&[2, 3, 4], 24),
            (&[1, 2048], 2048),
            (&[2048, 1000], 2_048_000),
        ];
        for (shape, expected) in cases.iter() {
            assert_eq!(numel(shape), *expected, "numel({:?})", shape);
        }
    }

    #[test]
    fn test_numel_zero_dimension() {
        // Any zero-sized dimension makes the tensor empty.
        assert_eq!(numel(&[0]), 0);
        assert_eq!(numel(&[3, 0, 5]), 0);

        // `u64::MAX.saturating_mul(0) == 0`, so a zero dimension that follows
        // an already-saturated product still yields the correct answer, 0.
        let mut saturated_then_zero = vec![u32::MAX; 10];
        saturated_then_zero.push(0);
        assert_eq!(numel(&saturated_then_zero), 0);
    }

    #[test]
    fn test_numel_large_values() {
        // (2^32 - 1)^2 = 18446744065119617025 still fits in a u64, so the
        // product is exact and no saturation occurs.
        assert_eq!(numel(&[u32::MAX, u32::MAX]), 18446744065119617025u64);

        // Ten factors of u32::MAX overflow u64; saturating_mul clamps the
        // result to u64::MAX instead of panicking or wrapping.
        assert_eq!(numel(&[u32::MAX; 10]), u64::MAX);
    }
}