1use serde::{Deserialize, Serialize};
2use std::collections::BTreeMap;
3
4use crate::ast::DataType;
5
6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct WeightsManifest {
8 pub format: String, pub version: u32, pub endianness: String,
11 pub tensors: BTreeMap<String, TensorEntry>,
12}
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct TensorEntry {
16 #[serde(rename = "dataType")]
17 pub data_type: DataType,
18 pub shape: Vec<u32>,
19 #[serde(rename = "byteOffset")]
20 pub byte_offset: u64,
21 #[serde(rename = "byteLength")]
22 pub byte_length: u64,
23 #[serde(default)]
24 pub layout: Option<String>,
25}
26
27pub fn dtype_size(dt: &DataType) -> u64 {
28 match dt {
29 DataType::Float32 => 4,
30 DataType::Float16 => 2,
31 DataType::Int4 | DataType::Uint4 => 1, DataType::Int32 => 4,
33 DataType::Uint32 => 4,
34 DataType::Int64 => 8,
35 DataType::Uint64 => 8,
36 DataType::Int8 => 1,
37 DataType::Uint8 => 1,
38 }
39}
40
/// Returns the total number of elements described by `shape` (the product of its dimensions).
///
/// An empty shape yields 1. The running product saturates at `u64::MAX` rather than
/// overflowing; a later zero dimension still brings the result to 0.
pub fn numel(shape: &[u32]) -> u64 {
    let mut total: u64 = 1;
    for &dim in shape {
        total = total.saturating_mul(u64::from(dim));
    }
    total
}
46
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dtype_size() {
        assert_eq!(dtype_size(&DataType::Float32), 4);
        assert_eq!(dtype_size(&DataType::Float16), 2);
        assert_eq!(dtype_size(&DataType::Int4), 1);
        assert_eq!(dtype_size(&DataType::Uint4), 1);
        assert_eq!(dtype_size(&DataType::Int32), 4);
        assert_eq!(dtype_size(&DataType::Uint32), 4);
        assert_eq!(dtype_size(&DataType::Int64), 8);
        assert_eq!(dtype_size(&DataType::Uint64), 8);
        assert_eq!(dtype_size(&DataType::Int8), 1);
        assert_eq!(dtype_size(&DataType::Uint8), 1);
    }

    #[test]
    fn test_numel() {
        assert_eq!(numel(&[]), 1);
        assert_eq!(numel(&[10]), 10);
        assert_eq!(numel(&[2, 3]), 6);
        assert_eq!(numel(&[2, 3, 4]), 24);
        assert_eq!(numel(&[1, 2048]), 2048);
        assert_eq!(numel(&[2048, 1000]), 2048000);
    }

    #[test]
    fn test_numel_zero_dimension() {
        assert_eq!(numel(&[0]), 0);
        assert_eq!(numel(&[3, 0, 5]), 0);
        // Once the product has saturated at u64::MAX, a zero dimension still yields 0.
        let mut shape = [u32::MAX; 11];
        shape[10] = 0;
        assert_eq!(numel(&shape), 0);
    }

    #[test]
    fn test_numel_large_values() {
        // (2^32 - 1)^2 = 2^64 - 2^33 + 1 still fits in u64.
        let large_shape = [u32::MAX, u32::MAX];
        assert_eq!(numel(&large_shape), 18446744065119617025u64);

        // Ten factors of u32::MAX overflow u64, so the product saturates.
        let very_large = [u32::MAX; 10];
        assert_eq!(numel(&very_large), u64::MAX);
    }
}