Skip to main content

webnn_graph/
weights.rs

1use serde::{Deserialize, Serialize};
2use std::collections::BTreeMap;
3
4use crate::ast::DataType;
5
6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct WeightsManifest {
8    pub format: String, // "wg-weights-manifest"
9    pub version: u32,   // 1
10    pub endianness: String,
11    pub tensors: BTreeMap<String, TensorEntry>,
12}
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct TensorEntry {
16    #[serde(rename = "dataType")]
17    pub data_type: DataType,
18    pub shape: Vec<u32>,
19    #[serde(rename = "byteOffset")]
20    pub byte_offset: u64,
21    #[serde(rename = "byteLength")]
22    pub byte_length: u64,
23    #[serde(default)]
24    pub layout: Option<String>,
25}
26
27pub fn dtype_size(dt: &DataType) -> u64 {
28    match dt {
29        DataType::Float32 => 4,
30        DataType::Float16 => 2,
31        DataType::Int4 | DataType::Uint4 => 1, // byte-per-value for 4-bit storage
32        DataType::Int32 => 4,
33        DataType::Uint32 => 4,
34        DataType::Int64 => 8,
35        DataType::Uint64 => 8,
36        DataType::Int8 => 1,
37        DataType::Uint8 => 1,
38    }
39}
40
/// Returns the total element count for a tensor with dimensions `shape`.
///
/// An empty shape yields 1 (the multiplicative identity). Dimensions are
/// widened to `u64` and multiplied with saturation, so a shape whose true
/// element count exceeds `u64::MAX` yields `u64::MAX` instead of wrapping
/// or panicking.
pub fn numel(shape: &[u32]) -> u64 {
    shape
        .iter()
        .map(|&dim| u64::from(dim))
        .fold(1, u64::saturating_mul)
}
46
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dtype_size() {
        let cases = [
            (DataType::Float32, 4),
            (DataType::Float16, 2),
            (DataType::Int4, 1),
            (DataType::Uint4, 1),
            (DataType::Int32, 4),
            (DataType::Uint32, 4),
            (DataType::Int64, 8),
            (DataType::Uint64, 8),
            (DataType::Int8, 1),
            (DataType::Uint8, 1),
        ];
        for (dtype, expected) in &cases {
            assert_eq!(
                dtype_size(dtype),
                *expected,
                "byte size mismatch for {:?}",
                dtype
            );
        }
    }

    #[test]
    fn test_numel() {
        let cases: &[(&[u32], u64)] = &[
            (&[], 1),
            (&[10], 10),
            (&[2, 3], 6),
            (&[2, 3, 4], 24),
            (&[1, 2048], 2048),
            (&[2048, 1000], 2_048_000),
        ];
        for (shape, expected) in cases {
            assert_eq!(
                numel(shape),
                *expected,
                "element count mismatch for shape {:?}",
                shape
            );
        }
    }

    #[test]
    fn test_numel_large_values() {
        // (2^32 - 1)^2 = 2^64 - 2^33 + 1 = 18446744065119617025, which still
        // fits in u64: the exact product is returned with no saturation.
        let two_max_dims = [u32::MAX; 2];
        assert_eq!(numel(&two_max_dims), 18446744065119617025u64);

        // Ten maximal dimensions overflow u64; saturating_mul clamps the
        // running product to u64::MAX rather than panicking or wrapping.
        let ten_max_dims = [u32::MAX; 10];
        assert_eq!(numel(&ten_max_dims), u64::MAX);
    }
}