1use half::f16;
79use serde::{
80 de::Error as DeserializeError, ser::Error as SerializeError, Deserialize, Deserializer,
81 Serialize, Serializer,
82};
83use std::{borrow::Cow, mem};
84use tch::{Device, Kind, Reduction, Tensor};
85
/// Plain-data stand-in for a [`Tensor`]; `serde_tensor` converts tensors to and
/// from this struct and lets serde handle it.
///
/// Serde's derive emits the fields in declaration order. Keep the order below
/// stable, because non-self-describing (positional) formats depend on it.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TensorRepr {
    /// Whether the tensor tracks gradients (`Tensor::requires_grad`).
    pub requires_grad: bool,
    /// Device the tensor lives on; encoded by `serde_device` as `"cpu"` or `"cuda:<index>"`.
    #[serde(with = "serde_device")]
    pub device: Device,
    /// Dimension sizes, as returned by `Tensor::size`.
    pub shape: Vec<i64>,
    /// Element type; encoded by `serde_kind` as a lowercase name such as `"float"`.
    #[serde(with = "serde_kind")]
    pub kind: Kind,
    /// Raw element bytes. `serde_tensor::serialize` fills this with
    /// `Tensor::copy_data_u8`, giving `numel * element size` bytes.
    /// NOTE(review): presumably in the host's native byte order, hence not
    /// portable across endianness — confirm.
    pub data: Vec<u8>,
}
99
100pub mod serde_tensor {
102 use super::*;
103
104 pub fn serialize<S>(tensor: &Tensor, serializer: S) -> Result<S::Ok, S::Error>
105 where
106 S: Serializer,
107 {
108 let device = tensor.device();
109 let requires_grad = tensor.requires_grad();
110 let shape = tensor.size();
111 let kind = tensor.kind();
112
113 let data = {
114 let numel = tensor.numel();
115 let elem_size = match kind {
116 Kind::Uint8 => mem::size_of::<u8>(),
117 Kind::Int8 => mem::size_of::<i8>(),
118 Kind::Int16 => mem::size_of::<i16>(),
119 Kind::Int => mem::size_of::<i32>(),
120 Kind::Int64 => mem::size_of::<i64>(),
121 Kind::Half => mem::size_of::<f16>(),
122 Kind::Float => mem::size_of::<f32>(),
123 Kind::Double => mem::size_of::<f64>(),
124 Kind::Bool => mem::size_of::<bool>(),
125 Kind::QInt8 => mem::size_of::<i8>(),
126 Kind::QUInt8 => mem::size_of::<u8>(),
127 Kind::QInt32 => mem::size_of::<i32>(),
128 Kind::BFloat16 => mem::size_of::<f16>(),
129 _ => {
130 return Err(S::Error::custom(format!(
131 "tensor with kind {:?} is not supported yet",
132 kind
133 )));
134 }
135 };
136 let buf_size = numel * elem_size;
137 let mut buffer = vec![0u8; buf_size];
138 tensor.copy_data_u8(&mut buffer, numel);
139 buffer
140 };
141
142 let repr = TensorRepr {
143 requires_grad,
144 device,
145 shape,
146 kind,
147 data,
148 };
149
150 repr.serialize(serializer)
151 }
152
153 pub fn deserialize<'de, D>(deserializer: D) -> Result<Tensor, D::Error>
154 where
155 D: Deserializer<'de>,
156 {
157 let TensorRepr {
158 requires_grad,
159 device,
160 shape,
161 kind,
162 data,
163 } = Deserialize::deserialize(deserializer)?;
164
165 let tensor = Tensor::of_data_size(&data, &shape, kind);
166 let tensor = tensor.set_requires_grad(requires_grad);
167 let tensor = tensor.to_device(device);
168
169 Ok(tensor)
170 }
171}
172
173pub mod serde_device {
175 use super::*;
176
177 pub fn serialize<S>(device: &Device, serializer: S) -> Result<S::Ok, S::Error>
178 where
179 S: Serializer,
180 {
181 let text = match device {
182 Device::Cpu => "cpu".into(),
183 Device::Cuda(n) => format!("cuda:{}", n),
184 };
185 serializer.serialize_str(&text)
186 }
187
188 pub fn deserialize<'de, D>(deserializer: D) -> Result<Device, D::Error>
189 where
190 D: Deserializer<'de>,
191 {
192 let text = String::deserialize(deserializer)?;
193 let device = match text.as_str() {
194 "cpu" => Device::Cpu,
195 other => {
196 let index = (move || -> Option<_> {
197 let remaining = other.strip_prefix("cuda:")?;
198 let index: usize = remaining.parse().ok()?;
199 Some(index)
200 })()
201 .ok_or_else(|| D::Error::custom(format!("invalid device name {}", text)))?;
202
203 Device::Cuda(index)
204 }
205 };
206
207 Ok(device)
208 }
209}
210
211pub mod serde_kind {
213 use super::*;
214
215 pub fn serialize<S>(kind: &Kind, serializer: S) -> Result<S::Ok, S::Error>
216 where
217 S: Serializer,
218 {
219 use Kind::*;
220 let text = match kind {
221 Uint8 => "uint8",
222 Int8 => "int8",
223 Int16 => "int16",
224 Int => "int",
225 Int64 => "int64",
226 Half => "half",
227 Float => "float",
228 Double => "double",
229 ComplexHalf => "complex_half",
230 ComplexFloat => "complex_float",
231 ComplexDouble => "complex_double",
232 Bool => "bool",
233 QInt8 => "qint8",
234 QUInt8 => "quint8",
235 QInt32 => "qint32",
236 BFloat16 => "bfloat16",
237 };
238 text.serialize(serializer)
239 }
240
241 pub fn deserialize<'de, D>(deserializer: D) -> Result<Kind, D::Error>
242 where
243 D: Deserializer<'de>,
244 {
245 use Kind::*;
246 let text = String::deserialize(deserializer)?;
247 let kind = match text.as_str() {
248 "uint8" => Uint8,
249 "int8" => Int8,
250 "int16" => Int16,
251 "int" => Int,
252 "int64" => Int64,
253 "half" => Half,
254 "float" => Float,
255 "double" => Double,
256 "complex_half" => ComplexHalf,
257 "complex_float" => ComplexFloat,
258 "complex_double" => ComplexDouble,
259 "bool" => Bool,
260 "qint8" => QInt8,
261 "quint8" => QUInt8,
262 "qint32" => QInt32,
263 "bfloat16" => BFloat16,
264 _ => return Err(D::Error::custom(format!(r#"invalid kind "{}""#, text))),
265 };
266 Ok(kind)
267 }
268}
269
270pub mod serde_reduction {
272 use super::*;
273
274 pub fn serialize<S>(reduction: &Reduction, serializer: S) -> Result<S::Ok, S::Error>
275 where
276 S: Serializer,
277 {
278 let text: Cow<'_, str> = match reduction {
279 Reduction::None => "none".into(),
280 Reduction::Mean => "mean".into(),
281 Reduction::Sum => "sum".into(),
282 Reduction::Other(value) => format!("other:{}", value).into(),
283 };
284 text.serialize(serializer)
285 }
286
287 pub fn deserialize<'de, D>(deserializer: D) -> Result<Reduction, D::Error>
288 where
289 D: Deserializer<'de>,
290 {
291 let text = String::deserialize(deserializer)?;
292
293 let reduction = match &*text {
294 "none" => Reduction::None,
295 "mean" => Reduction::Mean,
296 "sum" => Reduction::Sum,
297 other => {
298 let value = (move || -> Option<i64> {
299 let remaining = other.strip_prefix("other:")?;
300 let value: i64 = remaining.parse().ok()?;
301 Some(value)
302 })()
303 .ok_or_else(|| D::Error::custom(format!("invalid reduction '{}'", other)))?;
304 Reduction::Other(value)
305 }
306 };
307
308 Ok(reduction)
309 }
310}
311
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result;

    #[test]
    fn serde_reduction_test() -> Result<()> {
        #[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
        struct Example(#[serde(with = "serde_reduction")] Reduction);

        // Each case is checked in both directions.
        let cases = vec![
            (Reduction::None, r#""none""#),
            (Reduction::Mean, r#""mean""#),
            (Reduction::Sum, r#""sum""#),
            (Reduction::Other(1), r#""other:1""#),
            (Reduction::Other(3), r#""other:3""#),
        ];
        for (reduction, json) in cases {
            let example = Example(reduction);
            assert_eq!(serde_json::to_string(&example)?, json);
            assert_eq!(serde_json::from_str::<Example>(json)?, example);
        }

        for &invalid in &[r#""avg""#, r#""other:""#, r#""other:x""#] {
            assert!(
                serde_json::from_str::<Example>(invalid).is_err(),
                "accepted {}",
                invalid
            );
        }

        Ok(())
    }

    #[test]
    fn serde_device_test() -> Result<()> {
        #[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
        struct Example(#[serde(with = "serde_device")] Device);

        let cases = vec![
            (Device::Cpu, r#""cpu""#),
            (Device::Cuda(0), r#""cuda:0""#),
            (Device::Cuda(1), r#""cuda:1""#),
        ];
        for (device, json) in cases {
            let example = Example(device);
            assert_eq!(serde_json::to_string(&example)?, json);
            assert_eq!(serde_json::from_str::<Example>(json)?, example);
        }

        for &invalid in &[r#""gpu""#, r#""cuda""#, r#""cuda:""#, r#""cuda:-1""#] {
            assert!(
                serde_json::from_str::<Example>(invalid).is_err(),
                "accepted {}",
                invalid
            );
        }

        Ok(())
    }

    #[test]
    fn serde_kind_test() -> Result<()> {
        #[derive(Serialize, Deserialize, Debug, PartialEq, Eq)]
        struct Example(#[serde(with = "serde_kind")] Kind);

        let cases = vec![
            (Kind::Uint8, r#""uint8""#),
            (Kind::Int8, r#""int8""#),
            (Kind::Int16, r#""int16""#),
            (Kind::Int, r#""int""#),
            (Kind::Int64, r#""int64""#),
            (Kind::Half, r#""half""#),
            (Kind::Float, r#""float""#),
            (Kind::Double, r#""double""#),
            (Kind::ComplexHalf, r#""complex_half""#),
            (Kind::ComplexFloat, r#""complex_float""#),
            (Kind::ComplexDouble, r#""complex_double""#),
            (Kind::Bool, r#""bool""#),
            (Kind::QInt8, r#""qint8""#),
            (Kind::QUInt8, r#""quint8""#),
            (Kind::QInt32, r#""qint32""#),
            (Kind::BFloat16, r#""bfloat16""#),
        ];
        for (kind, json) in cases {
            let example = Example(kind);
            assert_eq!(serde_json::to_string(&example)?, json);
            assert_eq!(serde_json::from_str::<Example>(json)?, example);
        }

        for &invalid in &[r#""float32""#, r#""Float""#] {
            assert!(
                serde_json::from_str::<Example>(invalid).is_err(),
                "accepted {}",
                invalid
            );
        }

        Ok(())
    }

    /// Round-trips `tensor` through JSON and checks that shape, kind, device,
    /// gradient flag and values all survive.
    fn assert_round_trip(tensor: Tensor) -> Result<()> {
        #[derive(Debug, Serialize, Deserialize)]
        struct Example(#[serde(with = "serde_tensor")] Tensor);

        let orig = Example(tensor);
        let text = serde_json::to_string(&orig)?;
        let Example(recovered) = serde_json::from_str::<Example>(&text)?;
        let Example(orig) = orig;

        assert_eq!(orig.size(), recovered.size());
        assert_eq!(orig.kind(), recovered.kind());
        assert_eq!(orig.device(), recovered.device());
        assert_eq!(orig.requires_grad(), recovered.requires_grad());
        assert_eq!(orig, recovered);

        Ok(())
    }

    #[test]
    fn serde_tensor_test() -> Result<()> {
        let device = Device::cuda_if_available();

        for _ in 0..100 {
            assert_round_trip(Tensor::randn(&[3, 2, 4], (Kind::Float, device)))?;
        }

        for _ in 0..100 {
            assert_round_trip(Tensor::randint(1024, &[3, 2, 4], (Kind::Float, device)))?;
        }

        Ok(())
    }
}