Skip to main content

yscv_model/
checkpoint.rs

1use serde::{Deserialize, Serialize};
2use yscv_tensor::Tensor;
3
4use crate::ModelError;
5
6/// Serializable tensor snapshot used in model checkpoints.
7#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
8pub struct TensorSnapshot {
9    pub shape: Vec<usize>,
10    pub data: Vec<f32>,
11}
12
13impl TensorSnapshot {
14    pub fn from_tensor(tensor: &Tensor) -> Self {
15        Self {
16            shape: tensor.shape().to_vec(),
17            data: tensor.data().to_vec(),
18        }
19    }
20
21    pub fn into_tensor(self) -> Result<Tensor, ModelError> {
22        Tensor::from_vec(self.shape, self.data).map_err(Into::into)
23    }
24}
25
26/// Serializable layer checkpoint payload.
27#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
28#[serde(tag = "layer", content = "payload")]
29pub enum LayerCheckpoint {
30    Linear {
31        in_features: usize,
32        out_features: usize,
33        weight: TensorSnapshot,
34        bias: TensorSnapshot,
35    },
36    ReLU,
37    LeakyReLU {
38        negative_slope: f32,
39    },
40    Sigmoid,
41    Tanh,
42    Dropout {
43        rate: f32,
44    },
45    Conv2d {
46        in_channels: usize,
47        out_channels: usize,
48        kernel_h: usize,
49        kernel_w: usize,
50        stride_h: usize,
51        stride_w: usize,
52        weight: TensorSnapshot,
53        bias: Option<TensorSnapshot>,
54    },
55    BatchNorm2d {
56        num_features: usize,
57        epsilon: f32,
58        gamma: TensorSnapshot,
59        beta: TensorSnapshot,
60        running_mean: TensorSnapshot,
61        running_var: TensorSnapshot,
62    },
63    MaxPool2d {
64        kernel_h: usize,
65        kernel_w: usize,
66        stride_h: usize,
67        stride_w: usize,
68    },
69    AvgPool2d {
70        kernel_h: usize,
71        kernel_w: usize,
72        stride_h: usize,
73        stride_w: usize,
74    },
75    Flatten,
76    GlobalAvgPool2d,
77    Softmax,
78    Embedding {
79        num_embeddings: usize,
80        embedding_dim: usize,
81        weight: TensorSnapshot,
82    },
83    LayerNorm {
84        normalized_shape: usize,
85        eps: f32,
86        gamma: TensorSnapshot,
87        beta: TensorSnapshot,
88    },
89    GroupNorm {
90        num_groups: usize,
91        num_channels: usize,
92        eps: f32,
93        gamma: TensorSnapshot,
94        beta: TensorSnapshot,
95    },
96    DepthwiseConv2d {
97        channels: usize,
98        kernel_h: usize,
99        kernel_w: usize,
100        stride_h: usize,
101        stride_w: usize,
102        weight: TensorSnapshot,
103        bias: Option<TensorSnapshot>,
104    },
105    SeparableConv2d {
106        in_channels: usize,
107        out_channels: usize,
108        kernel_h: usize,
109        kernel_w: usize,
110        stride_h: usize,
111        stride_w: usize,
112        depthwise_weight: TensorSnapshot,
113        pointwise_weight: TensorSnapshot,
114        bias: Option<TensorSnapshot>,
115    },
116}
117
118/// Serializable sequential model checkpoint.
119#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
120pub struct SequentialCheckpoint {
121    pub layers: Vec<LayerCheckpoint>,
122}
123
124pub fn checkpoint_to_json(checkpoint: &SequentialCheckpoint) -> Result<String, ModelError> {
125    serde_json::to_string_pretty(checkpoint).map_err(|err| ModelError::CheckpointSerialization {
126        message: err.to_string(),
127    })
128}
129
130pub fn checkpoint_from_json(json: &str) -> Result<SequentialCheckpoint, ModelError> {
131    serde_json::from_str(json).map_err(|err| ModelError::CheckpointSerialization {
132        message: err.to_string(),
133    })
134}