1use serde::{
2 de::{Deserialize, Visitor},
3 ser::{Serialize, SerializeStruct},
4};
5
6use crate::device::cpu::Cpu;
7use crate::device::Device;
8use crate::dim::{DimDyn, DimTrait};
9use crate::matrix::{Matrix, Owned, Ptr, Repr};
10use crate::num::Num;
11
12impl<R: Repr, D: Device> Serialize for Matrix<R, DimDyn, D> {
13 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
14 where
15 S: serde::Serializer,
16 {
17 let shape = self.shape().slice().to_vec();
18 let stride = self.stride().slice().to_vec();
19 let data = self
20 .new_matrix()
21 .clone()
22 .to::<Cpu>()
23 .reshape([self.shape().num_elm()])
24 .to_vec();
25 let data_type = std::any::type_name::<R::Item>().to_string();
26 let ptr_offset = self.offset();
27
28 let mut state = serializer.serialize_struct("Matrix", 5)?;
29
30 state.serialize_field("shape", &shape)?;
31 state.serialize_field("stride", &stride)?;
32 state.serialize_field("data", &data)?;
33 state.serialize_field("data_type", &data_type)?;
34 state.serialize_field("ptr_offset", &ptr_offset)?;
35
36 state.end()
37 }
38}
39
40impl<'de, T: Num + Deserialize<'de>, D: Device> Deserialize<'de> for Matrix<Owned<T>, DimDyn, D> {
41 #[expect(clippy::too_many_lines)]
42 fn deserialize<Ds>(deserializer: Ds) -> Result<Self, Ds::Error>
43 where
44 Ds: serde::Deserializer<'de>,
45 {
46 enum Field {
47 Shape,
48 Stride,
49 Data,
50 DataType,
51 PtrOffset,
52 }
53
54 const FIELDS: &[&str] = &["shape", "stride", "data", "data_type", "ptr_offset"];
55
56 impl<'de> Deserialize<'de> for Field {
57 fn deserialize<D>(deserializer: D) -> Result<Field, D::Error>
58 where
59 D: serde::Deserializer<'de>,
60 {
61 struct FieldVisitor;
62
63 impl<'de> serde::de::Visitor<'de> for FieldVisitor {
64 type Value = Field;
65
66 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
67 formatter
68 .write_str("`shape`, `stride`, `data`, `data_type` or `ptr_offset`")
69 }
70
71 fn visit_str<E>(self, value: &str) -> Result<Field, E>
72 where
73 E: serde::de::Error,
74 {
75 match value {
76 "shape" => Ok(Field::Shape),
77 "stride" => Ok(Field::Stride),
78 "data" => Ok(Field::Data),
79 "data_type" => Ok(Field::DataType),
80 "ptr_offset" => Ok(Field::PtrOffset),
81 _ => Err(serde::de::Error::unknown_field(value, FIELDS)),
82 }
83 }
84 }
85
86 deserializer.deserialize_identifier(FieldVisitor)
87 }
88 }
89
90 struct MatrixVisitor<T: Num, D: Device>(std::marker::PhantomData<(T, D)>);
91
92 #[expect(clippy::similar_names)]
93 impl<'de, T: Num + Deserialize<'de>, D: Device> Visitor<'de> for MatrixVisitor<T, D> {
94 type Value = Matrix<Owned<T>, DimDyn, D>;
95
96 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
97 formatter.write_str("struct Matrix")
98 }
99
100 fn visit_map<V>(self, mut map: V) -> Result<Self::Value, V::Error>
101 where
102 V: serde::de::MapAccess<'de>,
103 {
104 let mut shape = None;
105 let mut stride = None;
106 let mut data = None;
107 let mut data_type = None;
108 let mut ptr_offset = None;
109
110 while let Some(key) = map.next_key()? {
111 match key {
112 Field::Shape => {
113 if shape.is_some() {
114 return Err(serde::de::Error::duplicate_field("shape"));
115 }
116 shape = Some(map.next_value()?);
117 }
118 Field::Stride => {
119 if stride.is_some() {
120 return Err(serde::de::Error::duplicate_field("stride"));
121 }
122 stride = Some(map.next_value()?);
123 }
124 Field::Data => {
125 if data.is_some() {
126 return Err(serde::de::Error::duplicate_field("data"));
127 }
128 data = Some(map.next_value()?);
129 }
130 Field::DataType => {
131 if data_type.is_some() {
132 return Err(serde::de::Error::duplicate_field("data_type"));
133 }
134 data_type = Some(map.next_value()?);
135 }
136 Field::PtrOffset => {
137 if ptr_offset.is_some() {
138 return Err(serde::de::Error::duplicate_field("ptr_offset"));
139 }
140 ptr_offset = Some(map.next_value()?);
141 }
142 }
143 }
144
145 let shape: Vec<usize> =
146 shape.ok_or_else(|| serde::de::Error::missing_field("shape"))?;
147 let stride: Vec<usize> =
148 stride.ok_or_else(|| serde::de::Error::missing_field("stride"))?;
149 let data: Vec<T> = data.ok_or_else(|| serde::de::Error::missing_field("data"))?;
150 let data_type: String =
151 data_type.ok_or_else(|| serde::de::Error::missing_field("data_type"))?;
152 let ptr_offset =
153 ptr_offset.ok_or_else(|| serde::de::Error::missing_field("ptr_offset"))?;
154
155 if std::any::type_name::<T>() != data_type {
156 return Err(serde::de::Error::custom("Data type mismatch"));
157 }
158
159 let shape = DimDyn::from(&shape as &[usize]);
160 let stride = DimDyn::from(&stride as &[usize]);
161
162 let mut data_cloned = data.clone();
163 let ptr =
164 Ptr::<Owned<T>, Cpu>::new(data_cloned.as_mut_ptr(), data.len(), ptr_offset);
165 std::mem::forget(data_cloned);
166
167 let mat = Matrix::new(ptr, shape, stride);
168 let mat_d: Matrix<Owned<T>, DimDyn, D> = mat.to();
169 Ok(mat_d)
170 }
171
172 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
173 where
174 A: serde::de::SeqAccess<'de>,
175 {
176 let shape: Vec<usize> = seq
177 .next_element()?
178 .ok_or_else(|| serde::de::Error::invalid_length(0, &""))?;
179 let stride: Vec<usize> = seq
180 .next_element()?
181 .ok_or_else(|| serde::de::Error::invalid_length(1, &""))?;
182 let data: Vec<T> = seq
183 .next_element()?
184 .ok_or_else(|| serde::de::Error::invalid_length(2, &""))?;
185 let data_type: String = seq
186 .next_element()?
187 .ok_or_else(|| serde::de::Error::invalid_length(3, &""))?;
188 let ptr_offset = seq
189 .next_element()?
190 .ok_or_else(|| serde::de::Error::invalid_length(4, &""))?;
191
192 if std::any::type_name::<T>() != data_type {
193 return Err(serde::de::Error::custom("Data type mismatch"));
194 }
195
196 let shape = DimDyn::from(&shape as &[usize]);
197 let stride = DimDyn::from(&stride as &[usize]);
198
199 let mut data_cloned = data.clone();
200 let ptr =
201 Ptr::<Owned<T>, Cpu>::new(data_cloned.as_mut_ptr(), data.len(), ptr_offset);
202 std::mem::forget(data_cloned);
203
204 let mat = Matrix::new(ptr, shape, stride);
205 let mat_d: Matrix<Owned<T>, DimDyn, D> = mat.to();
206 Ok(mat_d)
207 }
208 }
209
210 deserializer.deserialize_struct(
211 "Matrix",
212 FIELDS,
213 MatrixVisitor::<T, D>(std::marker::PhantomData),
214 )
215 }
216}
217
218#[cfg(test)]
219mod tests {
220 use super::*;
221 use zenu_test::assert_mat_eq_epsilon;
222
223 #[test]
224 fn test_matrix_serialization_deserialization() {
225 let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
226 let shape: Vec<usize> = vec![2, 3];
227 let matrix: Matrix<Owned<f64>, DimDyn, Cpu> =
228 Matrix::from_vec(data, DimDyn::from(&shape as &[usize]));
229
230 let serialized = serde_json::to_string(&matrix).expect("Failed to serialize matrix");
231
232 let deserialized: Matrix<Owned<f64>, DimDyn, Cpu> =
233 serde_json::from_str(&serialized).expect("Failed to deserialize matrix");
234
235 assert_eq!(matrix.shape(), deserialized.shape());
236 assert_eq!(matrix.stride(), deserialized.stride());
237
238 let original_data = matrix.to_vec();
239 let deserialized_data = deserialized.to_vec();
240 assert_eq!(original_data, deserialized_data);
241
242 assert_mat_eq_epsilon!(matrix, deserialized, 1e-6);
243 }
244
245 #[test]
246 fn test_matrix_serialization_format() {
247 let data = vec![1., 2., 3., 4.];
248 let shape: Vec<usize> = vec![2, 2];
249 let matrix: Matrix<Owned<f32>, DimDyn, Cpu> =
250 Matrix::from_vec(data, DimDyn::from(&shape as &[usize]));
251
252 let serialized = serde_json::to_string_pretty(&matrix).expect("Failed to serialize matrix");
253
254 let expected_json = r#"{
256 "shape": [
257 2,
258 2
259 ],
260 "stride": [
261 2,
262 1
263 ],
264 "data": [
265 1.0,
266 2.0,
267 3.0,
268 4.0
269 ],
270 "data_type": "f32",
271 "ptr_offset": 0
272}"#;
273
274 assert_eq!(serialized, expected_json);
275 }
276
277 #[test]
278 fn test_matrix_deserialization_error() {
279 let invalid_json = r#"{
280 "shape": [2, 2],
281 "stride": [2, 1],
282 "data": [1.0, 2.0, 3.0, 4.0],
283 "data_type": "f64",
284 "ptr_offset": 0
285 }"#;
286
287 let result: Result<Matrix<Owned<f32>, DimDyn, Cpu>, _> = serde_json::from_str(invalid_json);
288 assert!(result.is_err());
289 }
290}