tch_serde/
lib.rs

1//! Serialize/Deserialize [tch] types with [serde].
2//!
3//! The serializing and deserializing functions are grouped in the `serde_tensor`,
4//! `serde_kind`, `serde_device` and `serde_reduction` modules. Annotate a field with
5//! `#[serde(with = "tch_serde::serde_tensor")]` (or the matching module) to enable serialization.
6//!
7//! The snippet below serializes a compound type containing a [Tensor], [Kind], [Device] and [Reduction].
8//!
9//! ``` rust
10//! use tch::{Device, Kind, Reduction, Tensor};
11//!
12//! #[derive(Debug, serde::Serialize, serde::Deserialize)]
13//! struct Example {
14//!     #[serde(with = "tch_serde::serde_tensor")]
15//!     tensor: Tensor,
16//!     #[serde(with = "tch_serde::serde_kind")]
17//!     kind: Kind,
18//!     #[serde(with = "tch_serde::serde_device")]
19//!     device: Device,
20//!     #[serde(with = "tch_serde::serde_reduction")]
21//!     reduction: Reduction,
22//! }
23//!
24//! let example = Example {
25//!     tensor: Tensor::randn(&[2, 3], (Kind::Float, Device::cuda_if_available())),
26//!     kind: Kind::Float,
27//!     device: Device::Cpu,
28//!     reduction: Reduction::Mean,
29//! };
30//! let text = serde_json::to_string_pretty(&example).unwrap();
31//! println!("{}", text);
32//! ```
33//!
34//! On a CUDA machine it produces JSON text like the following (`data` holds the raw element bytes, so it differs between runs).
35//! ```json
36//! {
37//!   "tensor": {
38//!     "requires_grad": false,
39//!     "device": "cuda:0",
40//!     "shape": [
41//!       2,
42//!       3
43//!     ],
44//!     "kind": "float",
45//!     "data": [
46//!       182,
47//!       59,
48//!       207,
49//!       190,
50//!       12,
51//!       195,
52//!       95,
53//!       62,
54//!       123,
55//!       68,
56//!       200,
57//!       191,
58//!       242,
59//!       98,
60//!       231,
61//!       190,
62//!       108,
63//!       94,
64//!       225,
65//!       62,
66//!       56,
67//!       45,
68//!       3,
69//!       190
70//!     ]
71//!   },
72//!   "kind": "float",
73//!   "device": "cpu",
74//!   "reduction": "mean"
75//! }
76//! ```
77
78use half::f16;
79use serde::{
80    de::Error as DeserializeError, ser::Error as SerializeError, Deserialize, Deserializer,
81    Serialize, Serializer,
82};
83use std::{borrow::Cow, mem};
84use tch::{Device, Kind, Reduction, Tensor};
85
86/// The serialized representation of [Tensor].
87///
88/// The  [Tensor] is converted to this type during serialization.
89#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
90pub struct TensorRepr {
91    pub requires_grad: bool,
92    #[serde(with = "serde_device")]
93    pub device: Device,
94    pub shape: Vec<i64>,
95    #[serde(with = "serde_kind")]
96    pub kind: Kind,
97    pub data: Vec<u8>,
98}
99
100/// Serializing/Deserializing functions for [Tensor].
101pub mod serde_tensor {
102    use super::*;
103
104    pub fn serialize<S>(tensor: &Tensor, serializer: S) -> Result<S::Ok, S::Error>
105    where
106        S: Serializer,
107    {
108        let device = tensor.device();
109        let requires_grad = tensor.requires_grad();
110        let shape = tensor.size();
111        let kind = tensor.kind();
112
113        let data = {
114            let numel = tensor.numel();
115            let elem_size = match kind {
116                Kind::Uint8 => mem::size_of::<u8>(),
117                Kind::Int8 => mem::size_of::<i8>(),
118                Kind::Int16 => mem::size_of::<i16>(),
119                Kind::Int => mem::size_of::<i32>(),
120                Kind::Int64 => mem::size_of::<i64>(),
121                Kind::Half => mem::size_of::<f16>(),
122                Kind::Float => mem::size_of::<f32>(),
123                Kind::Double => mem::size_of::<f64>(),
124                Kind::Bool => mem::size_of::<bool>(),
125                Kind::QInt8 => mem::size_of::<i8>(),
126                Kind::QUInt8 => mem::size_of::<u8>(),
127                Kind::QInt32 => mem::size_of::<i32>(),
128                Kind::BFloat16 => mem::size_of::<f16>(),
129                _ => {
130                    return Err(S::Error::custom(format!(
131                        "tensor with kind {:?} is not supported yet",
132                        kind
133                    )));
134                }
135            };
136            let buf_size = numel * elem_size;
137            let mut buffer = vec![0u8; buf_size];
138            tensor.copy_data_u8(&mut buffer, numel);
139            buffer
140        };
141
142        let repr = TensorRepr {
143            requires_grad,
144            device,
145            shape,
146            kind,
147            data,
148        };
149
150        repr.serialize(serializer)
151    }
152
153    pub fn deserialize<'de, D>(deserializer: D) -> Result<Tensor, D::Error>
154    where
155        D: Deserializer<'de>,
156    {
157        let TensorRepr {
158            requires_grad,
159            device,
160            shape,
161            kind,
162            data,
163        } = Deserialize::deserialize(deserializer)?;
164
165        let tensor = Tensor::of_data_size(&data, &shape, kind);
166        let tensor = tensor.set_requires_grad(requires_grad);
167        let tensor = tensor.to_device(device);
168
169        Ok(tensor)
170    }
171}
172
173/// Serializing/Deserializing functions for [Device].
174pub mod serde_device {
175    use super::*;
176
177    pub fn serialize<S>(device: &Device, serializer: S) -> Result<S::Ok, S::Error>
178    where
179        S: Serializer,
180    {
181        let text = match device {
182            Device::Cpu => "cpu".into(),
183            Device::Cuda(n) => format!("cuda:{}", n),
184        };
185        serializer.serialize_str(&text)
186    }
187
188    pub fn deserialize<'de, D>(deserializer: D) -> Result<Device, D::Error>
189    where
190        D: Deserializer<'de>,
191    {
192        let text = String::deserialize(deserializer)?;
193        let device = match text.as_str() {
194            "cpu" => Device::Cpu,
195            other => {
196                let index = (move || -> Option<_> {
197                    let remaining = other.strip_prefix("cuda:")?;
198                    let index: usize = remaining.parse().ok()?;
199                    Some(index)
200                })()
201                .ok_or_else(|| D::Error::custom(format!("invalid device name {}", text)))?;
202
203                Device::Cuda(index)
204            }
205        };
206
207        Ok(device)
208    }
209}
210
211/// Serializing/Deserializing functions for [Kind].
212pub mod serde_kind {
213    use super::*;
214
215    pub fn serialize<S>(kind: &Kind, serializer: S) -> Result<S::Ok, S::Error>
216    where
217        S: Serializer,
218    {
219        use Kind::*;
220        let text = match kind {
221            Uint8 => "uint8",
222            Int8 => "int8",
223            Int16 => "int16",
224            Int => "int",
225            Int64 => "int64",
226            Half => "half",
227            Float => "float",
228            Double => "double",
229            ComplexHalf => "complex_half",
230            ComplexFloat => "complex_float",
231            ComplexDouble => "complex_double",
232            Bool => "bool",
233            QInt8 => "qint8",
234            QUInt8 => "quint8",
235            QInt32 => "qint32",
236            BFloat16 => "bfloat16",
237        };
238        text.serialize(serializer)
239    }
240
241    pub fn deserialize<'de, D>(deserializer: D) -> Result<Kind, D::Error>
242    where
243        D: Deserializer<'de>,
244    {
245        use Kind::*;
246        let text = String::deserialize(deserializer)?;
247        let kind = match text.as_str() {
248            "uint8" => Uint8,
249            "int8" => Int8,
250            "int16" => Int16,
251            "int" => Int,
252            "int64" => Int64,
253            "half" => Half,
254            "float" => Float,
255            "double" => Double,
256            "complex_half" => ComplexHalf,
257            "complex_float" => ComplexFloat,
258            "complex_double" => ComplexDouble,
259            "bool" => Bool,
260            "qint8" => QInt8,
261            "quint8" => QUInt8,
262            "qint32" => QInt32,
263            "bfloat16" => BFloat16,
264            _ => return Err(D::Error::custom(format!(r#"invalid kind "{}""#, text))),
265        };
266        Ok(kind)
267    }
268}
269
270/// Serializing/Deserializing functions for [Reduction].
271pub mod serde_reduction {
272    use super::*;
273
274    pub fn serialize<S>(reduction: &Reduction, serializer: S) -> Result<S::Ok, S::Error>
275    where
276        S: Serializer,
277    {
278        let text: Cow<'_, str> = match reduction {
279            Reduction::None => "none".into(),
280            Reduction::Mean => "mean".into(),
281            Reduction::Sum => "sum".into(),
282            Reduction::Other(value) => format!("other:{}", value).into(),
283        };
284        text.serialize(serializer)
285    }
286
287    pub fn deserialize<'de, D>(deserializer: D) -> Result<Reduction, D::Error>
288    where
289        D: Deserializer<'de>,
290    {
291        let text = String::deserialize(deserializer)?;
292
293        let reduction = match &*text {
294            "none" => Reduction::None,
295            "mean" => Reduction::Mean,
296            "sum" => Reduction::Sum,
297            other => {
298                let value = (move || -> Option<i64> {
299                    let remaining = other.strip_prefix("other:")?;
300                    let value: i64 = remaining.parse().ok()?;
301                    Some(value)
302                })()
303                .ok_or_else(|| D::Error::custom(format!("invalid reduction '{}'", other)))?;
304                Reduction::Other(value)
305            }
306        };
307
308        Ok(reduction)
309    }
310}
311
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result;

    #[test]
    fn serde_reduction_test() -> Result<()> {
        #[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
        struct Example(#[serde(with = "serde_reduction")] Reduction);

        // Every case must serialize to `json` and deserialize back from it.
        let cases = vec![
            (Reduction::None, r#""none""#),
            (Reduction::Mean, r#""mean""#),
            (Reduction::Sum, r#""sum""#),
            (Reduction::Other(1), r#""other:1""#),
            (Reduction::Other(3), r#""other:3""#),
        ];
        for (reduction, json) in cases {
            let example = Example(reduction);
            assert_eq!(serde_json::to_string(&example)?, json);
            assert_eq!(serde_json::from_str::<Example>(json)?, example);
        }

        Ok(())
    }

    #[test]
    fn serde_device_test() -> Result<()> {
        #[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
        struct Example(#[serde(with = "serde_device")] Device);

        let cases = vec![
            (Device::Cpu, r#""cpu""#),
            (Device::Cuda(0), r#""cuda:0""#),
            (Device::Cuda(1), r#""cuda:1""#),
        ];
        for (device, json) in cases {
            let example = Example(device);
            assert_eq!(serde_json::to_string(&example)?, json);
            assert_eq!(serde_json::from_str::<Example>(json)?, example);
        }

        Ok(())
    }

    #[test]
    fn serde_kind_test() -> Result<()> {
        #[derive(Serialize, Deserialize, Debug, PartialEq, Eq)]
        struct Example(#[serde(with = "serde_kind")] Kind);

        let cases = vec![
            (Kind::Uint8, r#""uint8""#),
            (Kind::Int8, r#""int8""#),
            (Kind::Int16, r#""int16""#),
            (Kind::Int, r#""int""#),
            (Kind::Int64, r#""int64""#),
            (Kind::Half, r#""half""#),
            (Kind::Float, r#""float""#),
            (Kind::Double, r#""double""#),
            (Kind::ComplexHalf, r#""complex_half""#),
            (Kind::ComplexFloat, r#""complex_float""#),
            (Kind::ComplexDouble, r#""complex_double""#),
            (Kind::Bool, r#""bool""#),
            (Kind::QInt8, r#""qint8""#),
            (Kind::QUInt8, r#""quint8""#),
            (Kind::QInt32, r#""qint32""#),
            (Kind::BFloat16, r#""bfloat16""#),
        ];
        for (kind, json) in cases {
            let example = Example(kind);
            assert_eq!(serde_json::to_string(&example)?, json);
            assert_eq!(serde_json::from_str::<Example>(json)?, example);
        }

        Ok(())
    }

    #[test]
    fn serde_tensor() -> Result<()> {
        #[derive(Debug, Serialize, Deserialize)]
        struct Example(#[serde(with = "serde_tensor")] Tensor);

        // Serializes `tensor` to JSON, reads it back and checks that nothing changed.
        let check_round_trip = |tensor: Tensor| -> Result<()> {
            let orig = Example(tensor);
            let text = serde_json::to_string(&orig)?;
            let recovered: Example = serde_json::from_str(&text)?;

            let (Example(orig_tensor), Example(recovered_tensor)) = (orig, recovered);
            assert_eq!(orig_tensor.size(), recovered_tensor.size());
            assert_eq!(orig_tensor.kind(), recovered_tensor.kind());
            assert_eq!(orig_tensor, recovered_tensor);
            Ok(())
        };

        for _ in 0..100 {
            check_round_trip(Tensor::randn(
                &[3, 2, 4],
                (Kind::Float, Device::cuda_if_available()),
            ))?;
        }

        for _ in 0..100 {
            check_round_trip(Tensor::randint(
                1024,
                &[3, 2, 4],
                (Kind::Float, Device::cuda_if_available()),
            ))?;
        }

        Ok(())
    }
}