1use serde::{Deserialize, Serialize};
2use std::collections::BTreeMap;
3
4use crate::ast::DataType;
5
6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct WeightsManifest {
8 pub format: String, pub version: u32, pub endianness: String,
11 pub tensors: BTreeMap<String, TensorEntry>,
12}
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct TensorEntry {
16 #[serde(rename = "dataType")]
17 pub data_type: DataType,
18 pub shape: Vec<u32>,
19 #[serde(rename = "byteOffset")]
20 pub byte_offset: u64,
21 #[serde(rename = "byteLength")]
22 pub byte_length: u64,
23 #[serde(default)]
24 pub layout: Option<String>,
25}
26
27pub fn dtype_size(dt: &DataType) -> u64 {
28 match dt {
29 DataType::Float32 => 4,
30 DataType::Float16 => 2,
31 DataType::Int32 => 4,
32 DataType::Uint32 => 4,
33 DataType::Int64 => 8,
34 DataType::Uint64 => 8,
35 DataType::Int8 => 1,
36 DataType::Uint8 => 1,
37 }
38}
39
/// Returns the number of elements described by `shape`.
///
/// An empty shape yields `1`. The running product saturates at [`u64::MAX`]
/// instead of overflowing. Each dimension is still multiplied in order, so a
/// zero dimension after a saturated prefix still produces `0`.
pub fn numel(shape: &[u32]) -> u64 {
    let mut count: u64 = 1;
    for &dim in shape {
        count = count.saturating_mul(u64::from(dim));
    }
    count
}
45
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dtype_size() {
        assert_eq!(dtype_size(&DataType::Float32), 4);
        assert_eq!(dtype_size(&DataType::Float16), 2);
        assert_eq!(dtype_size(&DataType::Int32), 4);
        assert_eq!(dtype_size(&DataType::Uint32), 4);
        assert_eq!(dtype_size(&DataType::Int64), 8);
        assert_eq!(dtype_size(&DataType::Uint64), 8);
        assert_eq!(dtype_size(&DataType::Int8), 1);
        assert_eq!(dtype_size(&DataType::Uint8), 1);
    }

    #[test]
    fn test_numel() {
        // An empty shape is a scalar: exactly one element.
        assert_eq!(numel(&[]), 1);
        assert_eq!(numel(&[10]), 10);
        assert_eq!(numel(&[2, 3]), 6);
        assert_eq!(numel(&[2, 3, 4]), 24);
        assert_eq!(numel(&[1, 2048]), 2048);
        assert_eq!(numel(&[2048, 1000]), 2_048_000);
    }

    #[test]
    fn test_numel_large_values() {
        // (2^32 - 1)^2 = 2^64 - 2^33 + 1, which still fits in a u64.
        let large_shape = [u32::MAX, u32::MAX];
        assert_eq!(numel(&large_shape), 18_446_744_065_119_617_025);

        // Ten maximal dimensions overflow u64, so the product saturates.
        let very_large = [u32::MAX; 10];
        assert_eq!(numel(&very_large), u64::MAX);
    }

    #[test]
    fn test_numel_zero_dimension() {
        assert_eq!(numel(&[0]), 0);
        assert_eq!(numel(&[3, 0, 5]), 0);
        // A zero after a saturated prefix still collapses the product to 0.
        assert_eq!(numel(&[u32::MAX, u32::MAX, u32::MAX, 0]), 0);
    }
}