zenu_matrix/
impl_serde.rs

1use serde::{
2    de::{Deserialize, Visitor},
3    ser::{Serialize, SerializeStruct},
4};
5
6use crate::device::cpu::Cpu;
7use crate::device::Device;
8use crate::dim::{DimDyn, DimTrait};
9use crate::matrix::{Matrix, Owned, Ptr, Repr};
10use crate::num::Num;
11
12impl<R: Repr, D: Device> Serialize for Matrix<R, DimDyn, D> {
13    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
14    where
15        S: serde::Serializer,
16    {
17        let shape = self.shape().slice().to_vec();
18        let stride = self.stride().slice().to_vec();
19        let data = self
20            .new_matrix()
21            .clone()
22            .to::<Cpu>()
23            .reshape([self.shape().num_elm()])
24            .to_vec();
25        let data_type = std::any::type_name::<R::Item>().to_string();
26        let ptr_offset = self.offset();
27
28        let mut state = serializer.serialize_struct("Matrix", 5)?;
29
30        state.serialize_field("shape", &shape)?;
31        state.serialize_field("stride", &stride)?;
32        state.serialize_field("data", &data)?;
33        state.serialize_field("data_type", &data_type)?;
34        state.serialize_field("ptr_offset", &ptr_offset)?;
35
36        state.end()
37    }
38}
39
40impl<'de, T: Num + Deserialize<'de>, D: Device> Deserialize<'de> for Matrix<Owned<T>, DimDyn, D> {
41    #[expect(clippy::too_many_lines)]
42    fn deserialize<Ds>(deserializer: Ds) -> Result<Self, Ds::Error>
43    where
44        Ds: serde::Deserializer<'de>,
45    {
46        enum Field {
47            Shape,
48            Stride,
49            Data,
50            DataType,
51            PtrOffset,
52        }
53
54        const FIELDS: &[&str] = &["shape", "stride", "data", "data_type", "ptr_offset"];
55
56        impl<'de> Deserialize<'de> for Field {
57            fn deserialize<D>(deserializer: D) -> Result<Field, D::Error>
58            where
59                D: serde::Deserializer<'de>,
60            {
61                struct FieldVisitor;
62
63                impl<'de> serde::de::Visitor<'de> for FieldVisitor {
64                    type Value = Field;
65
66                    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
67                        formatter
68                            .write_str("`shape`, `stride`, `data`, `data_type` or `ptr_offset`")
69                    }
70
71                    fn visit_str<E>(self, value: &str) -> Result<Field, E>
72                    where
73                        E: serde::de::Error,
74                    {
75                        match value {
76                            "shape" => Ok(Field::Shape),
77                            "stride" => Ok(Field::Stride),
78                            "data" => Ok(Field::Data),
79                            "data_type" => Ok(Field::DataType),
80                            "ptr_offset" => Ok(Field::PtrOffset),
81                            _ => Err(serde::de::Error::unknown_field(value, FIELDS)),
82                        }
83                    }
84                }
85
86                deserializer.deserialize_identifier(FieldVisitor)
87            }
88        }
89
90        struct MatrixVisitor<T: Num, D: Device>(std::marker::PhantomData<(T, D)>);
91
92        #[expect(clippy::similar_names)]
93        impl<'de, T: Num + Deserialize<'de>, D: Device> Visitor<'de> for MatrixVisitor<T, D> {
94            type Value = Matrix<Owned<T>, DimDyn, D>;
95
96            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
97                formatter.write_str("struct Matrix")
98            }
99
100            fn visit_map<V>(self, mut map: V) -> Result<Self::Value, V::Error>
101            where
102                V: serde::de::MapAccess<'de>,
103            {
104                let mut shape = None;
105                let mut stride = None;
106                let mut data = None;
107                let mut data_type = None;
108                let mut ptr_offset = None;
109
110                while let Some(key) = map.next_key()? {
111                    match key {
112                        Field::Shape => {
113                            if shape.is_some() {
114                                return Err(serde::de::Error::duplicate_field("shape"));
115                            }
116                            shape = Some(map.next_value()?);
117                        }
118                        Field::Stride => {
119                            if stride.is_some() {
120                                return Err(serde::de::Error::duplicate_field("stride"));
121                            }
122                            stride = Some(map.next_value()?);
123                        }
124                        Field::Data => {
125                            if data.is_some() {
126                                return Err(serde::de::Error::duplicate_field("data"));
127                            }
128                            data = Some(map.next_value()?);
129                        }
130                        Field::DataType => {
131                            if data_type.is_some() {
132                                return Err(serde::de::Error::duplicate_field("data_type"));
133                            }
134                            data_type = Some(map.next_value()?);
135                        }
136                        Field::PtrOffset => {
137                            if ptr_offset.is_some() {
138                                return Err(serde::de::Error::duplicate_field("ptr_offset"));
139                            }
140                            ptr_offset = Some(map.next_value()?);
141                        }
142                    }
143                }
144
145                let shape: Vec<usize> =
146                    shape.ok_or_else(|| serde::de::Error::missing_field("shape"))?;
147                let stride: Vec<usize> =
148                    stride.ok_or_else(|| serde::de::Error::missing_field("stride"))?;
149                let data: Vec<T> = data.ok_or_else(|| serde::de::Error::missing_field("data"))?;
150                let data_type: String =
151                    data_type.ok_or_else(|| serde::de::Error::missing_field("data_type"))?;
152                let ptr_offset =
153                    ptr_offset.ok_or_else(|| serde::de::Error::missing_field("ptr_offset"))?;
154
155                if std::any::type_name::<T>() != data_type {
156                    return Err(serde::de::Error::custom("Data type mismatch"));
157                }
158
159                let shape = DimDyn::from(&shape as &[usize]);
160                let stride = DimDyn::from(&stride as &[usize]);
161
162                let mut data_cloned = data.clone();
163                let ptr =
164                    Ptr::<Owned<T>, Cpu>::new(data_cloned.as_mut_ptr(), data.len(), ptr_offset);
165                std::mem::forget(data_cloned);
166
167                let mat = Matrix::new(ptr, shape, stride);
168                let mat_d: Matrix<Owned<T>, DimDyn, D> = mat.to();
169                Ok(mat_d)
170            }
171
172            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
173            where
174                A: serde::de::SeqAccess<'de>,
175            {
176                let shape: Vec<usize> = seq
177                    .next_element()?
178                    .ok_or_else(|| serde::de::Error::invalid_length(0, &""))?;
179                let stride: Vec<usize> = seq
180                    .next_element()?
181                    .ok_or_else(|| serde::de::Error::invalid_length(1, &""))?;
182                let data: Vec<T> = seq
183                    .next_element()?
184                    .ok_or_else(|| serde::de::Error::invalid_length(2, &""))?;
185                let data_type: String = seq
186                    .next_element()?
187                    .ok_or_else(|| serde::de::Error::invalid_length(3, &""))?;
188                let ptr_offset = seq
189                    .next_element()?
190                    .ok_or_else(|| serde::de::Error::invalid_length(4, &""))?;
191
192                if std::any::type_name::<T>() != data_type {
193                    return Err(serde::de::Error::custom("Data type mismatch"));
194                }
195
196                let shape = DimDyn::from(&shape as &[usize]);
197                let stride = DimDyn::from(&stride as &[usize]);
198
199                let mut data_cloned = data.clone();
200                let ptr =
201                    Ptr::<Owned<T>, Cpu>::new(data_cloned.as_mut_ptr(), data.len(), ptr_offset);
202                std::mem::forget(data_cloned);
203
204                let mat = Matrix::new(ptr, shape, stride);
205                let mat_d: Matrix<Owned<T>, DimDyn, D> = mat.to();
206                Ok(mat_d)
207            }
208        }
209
210        deserializer.deserialize_struct(
211            "Matrix",
212            FIELDS,
213            MatrixVisitor::<T, D>(std::marker::PhantomData),
214        )
215    }
216}
217
#[cfg(test)]
mod tests {
    use super::*;
    use zenu_test::assert_mat_eq_epsilon;

    /// A matrix survives a JSON round trip with the same shape, stride and elements.
    #[test]
    fn test_matrix_serialization_deserialization() {
        let shape: Vec<usize> = vec![2, 3];
        let original: Matrix<Owned<f64>, DimDyn, Cpu> = Matrix::from_vec(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            DimDyn::from(&shape as &[usize]),
        );

        let json = serde_json::to_string(&original).expect("Failed to serialize matrix");
        let restored: Matrix<Owned<f64>, DimDyn, Cpu> =
            serde_json::from_str(&json).expect("Failed to deserialize matrix");

        assert_eq!(original.shape(), restored.shape());
        assert_eq!(original.stride(), restored.stride());
        assert_eq!(original.to_vec(), restored.to_vec());
        assert_mat_eq_epsilon!(original, restored, 1e-6);
    }

    /// Pretty-printed output lists the five fields in declaration order.
    #[test]
    fn test_matrix_serialization_format() {
        let shape: Vec<usize> = vec![2, 2];
        let matrix: Matrix<Owned<f32>, DimDyn, Cpu> =
            Matrix::from_vec(vec![1., 2., 3., 4.], DimDyn::from(&shape as &[usize]));

        // Check the exact expected JSON layout.
        let expected_json = r#"{
  "shape": [
    2,
    2
  ],
  "stride": [
    2,
    1
  ],
  "data": [
    1.0,
    2.0,
    3.0,
    4.0
  ],
  "data_type": "f32",
  "ptr_offset": 0
}"#;

        let serialized = serde_json::to_string_pretty(&matrix).expect("Failed to serialize matrix");
        assert_eq!(serialized, expected_json);
    }

    /// Data tagged as `f64` must not deserialize into an `f32` matrix.
    #[test]
    fn test_matrix_deserialization_error() {
        let invalid_json = r#"{
            "shape": [2, 2],
            "stride": [2, 1],
            "data": [1.0, 2.0, 3.0, 4.0],
            "data_type": "f64",
            "ptr_offset": 0
        }"#;

        assert!(serde_json::from_str::<Matrix<Owned<f32>, DimDyn, Cpu>>(invalid_json).is_err());
    }
}