// serde_save/imp.rs

use crate::{Error, Save, Variant};
use core::{cmp, convert::Infallible, fmt, marker::PhantomData};
use std::collections::BTreeSet;

5mod sealed {
6    pub trait Sealed {}
7    impl Sealed for super::ShortCircuit {}
8    impl Sealed for super::Persist {}
9}
10
11pub trait ErrorDiscipline: sealed::Sealed {
12    type SaveError;
13    fn handle(res: Result<Save<Self::SaveError>, Error>) -> Result<Save<Self::SaveError>, Error>;
14}
15
/// [`ErrorDiscipline`] which aborts serialization at the first error.
///
/// Uninhabited: it is only ever used as a type parameter.
pub enum ShortCircuit {}
/// [`ErrorDiscipline`] which records failures in-tree as [`Save::Error`] and keeps going.
///
/// Uninhabited: it is only ever used as a type parameter.
pub enum Persist {}

19impl ErrorDiscipline for ShortCircuit {
20    type SaveError = Infallible;
21    fn handle(res: Result<Save<Self::SaveError>, Error>) -> Result<Save<Self::SaveError>, Error> {
22        res
23    }
24}
25
26impl ErrorDiscipline for Persist {
27    type SaveError = Error;
28    fn handle(res: Result<Save<Self::SaveError>, Error>) -> Result<Save<Self::SaveError>, Error> {
29        Ok(res.unwrap_or_else(Save::Error))
30    }
31}
32
33/// Serializer which produces [`Save`]s.
34///
35/// See [crate documentation](mod@super) for more.
36pub struct Serializer<ErrorDiscipline = ShortCircuit> {
37    config: Config<ErrorDiscipline>,
38}
39
40impl Serializer<ShortCircuit> {
41    /// Create a serializer which is:
42    /// - [human readable](`serde::Serializer::is_human_readable`) (this is the default for serde formats).
43    /// - NOT sensitive to [protocol errors](Self::check_for_protocol_errors).
44    pub fn new() -> Self {
45        Self {
46            config: Config {
47                is_human_readable: true,
48                protocol_errors: false,
49                _error_discipline: PhantomData,
50            },
51        }
52    }
53}
54
55impl<E> Serializer<E> {
56    /// See [`serde::Serializer::is_human_readable`].
57    pub fn human_readable(mut self, is_human_readable: bool) -> Self {
58        self.config.is_human_readable = is_human_readable;
59        self
60    }
61    /// Whether to check for incorrect implementations of e.g [`serde::ser::SerializeSeq`].
62    /// See documentation on variants of [`Save`] for the invariants which are checked.
63    pub fn check_for_protocol_errors(mut self, check: bool) -> Self {
64        self.config.protocol_errors = check;
65        self
66    }
67    /// Persist the errors in-tree.
68    ///
69    /// If any node's implementation of [`serde::Serialize::serialize`] fails, it
70    /// will be recorded as a [`Save::Error`].
71    ///
72    /// If there are any [protocol errors](Self::check_for_protocol_errors), they
73    /// will be recorded as the final element(s) of the corresponding collection.
74    pub fn save_errors(self) -> Serializer<Persist> {
75        let Self {
76            config:
77                Config {
78                    is_human_readable,
79                    protocol_errors,
80                    _error_discipline,
81                },
82        } = self;
83        Serializer {
84            config: Config {
85                is_human_readable,
86                protocol_errors,
87                _error_discipline: PhantomData,
88            },
89        }
90    }
91}
92
93impl Default for Serializer {
94    /// See [`Self::new`].
95    fn default() -> Self {
96        Self::new()
97    }
98}
99
100struct Config<E = ShortCircuit> {
101    is_human_readable: bool,
102    protocol_errors: bool,
103    _error_discipline: PhantomData<fn() -> E>,
104}
105
106impl<E> Clone for Config<E> {
107    fn clone(&self) -> Self {
108        *self
109    }
110}
111impl<E> Copy for Config<E> {}
112
// Expands each `method(Type) -> Variant` entry into a serializer method that
// wraps its argument directly in `Save::Variant`. Must be invoked inside an
// impl whose `Self::Ok` is a `Save`.
macro_rules! simple {
    ($($name:ident($arg:ty) -> $ctor:ident);* $(;)?) => {
        $(
            fn $name(self, value: $arg) -> Result<Self::Ok, Self::Error> {
                Ok(Save::$ctor(value))
            }
        )*
    };
}

123impl<E> serde::Serializer for Serializer<E>
124where
125    E: ErrorDiscipline,
126{
127    type Ok = Save<'static, E::SaveError>;
128    type Error = Error;
129    type SerializeSeq = SerializeSeq<E>;
130    type SerializeTuple = SerializeTuple<E>;
131    type SerializeTupleStruct = SerializeTupleStruct<E>;
132    type SerializeTupleVariant = SerializeTupleVariant<E>;
133    type SerializeMap = SerializeMap<E>;
134    type SerializeStruct = SerializeStruct<E>;
135    type SerializeStructVariant = SerializeStructVariant<E>;
136
137    fn is_human_readable(&self) -> bool {
138        self.config.is_human_readable
139    }
140
141    simple! {
142        serialize_bool(bool) -> Bool;
143        serialize_i8(i8) -> I8;
144        serialize_i16(i16) -> I16;
145        serialize_i32(i32) -> I32;
146        serialize_i64(i64) -> I64;
147        serialize_u8(u8) -> U8;
148        serialize_u16(u16) -> U16;
149        serialize_u32(u32) -> U32;
150        serialize_u64(u64) -> U64;
151        serialize_f32(f32) -> F32;
152        serialize_f64(f64) -> F64;
153        serialize_char(char) -> Char;
154    }
155
156    fn serialize_str(self, v: &str) -> Result<Self::Ok, Self::Error> {
157        Ok(Save::String(v.into()))
158    }
159    fn collect_str<T: ?Sized + fmt::Display>(self, value: &T) -> Result<Self::Ok, Self::Error> {
160        Ok(Save::String(value.to_string()))
161    }
162    fn serialize_bytes(self, v: &[u8]) -> Result<Self::Ok, Self::Error> {
163        Ok(Save::ByteArray(v.into()))
164    }
165    fn serialize_none(self) -> Result<Self::Ok, Self::Error> {
166        Ok(Save::Option(None))
167    }
168    fn serialize_some<T: ?Sized + serde::Serialize>(
169        self,
170        value: &T,
171    ) -> Result<Self::Ok, Self::Error> {
172        Ok(Save::Option(Some(Box::new(E::handle(
173            value.serialize(self),
174        )?))))
175    }
176    fn serialize_unit(self) -> Result<Self::Ok, Self::Error> {
177        Ok(Save::Unit)
178    }
179    fn serialize_unit_struct(self, name: &'static str) -> Result<Self::Ok, Self::Error> {
180        Ok(Save::UnitStruct(name))
181    }
182    fn serialize_unit_variant(
183        self,
184        name: &'static str,
185        variant_index: u32,
186        variant: &'static str,
187    ) -> Result<Self::Ok, Self::Error> {
188        Ok(Save::UnitVariant(Variant {
189            name,
190            variant_index,
191            variant,
192        }))
193    }
194    fn serialize_newtype_struct<T: ?Sized + serde::Serialize>(
195        self,
196        name: &'static str,
197        value: &T,
198    ) -> Result<Self::Ok, Self::Error> {
199        Ok(Save::NewTypeStruct {
200            name,
201            value: Box::new(E::handle(value.serialize(self))?),
202        })
203    }
204    fn serialize_newtype_variant<T: ?Sized + serde::Serialize>(
205        self,
206        name: &'static str,
207        variant_index: u32,
208        variant: &'static str,
209        value: &T,
210    ) -> Result<Self::Ok, Self::Error> {
211        Ok(Save::NewTypeVariant {
212            variant: Variant {
213                name,
214                variant_index,
215                variant,
216            },
217            value: Box::new(E::handle(value.serialize(self))?),
218        })
219    }
220    fn serialize_seq(self, len: Option<usize>) -> Result<Self::SerializeSeq, Self::Error> {
221        Ok(SerializeSeq {
222            config: self.config,
223            inner: Vec::with_capacity(len.unwrap_or_default()),
224            expected_len: len,
225        })
226    }
227    fn serialize_tuple(self, len: usize) -> Result<Self::SerializeTuple, Self::Error> {
228        Ok(SerializeTuple {
229            config: self.config,
230            inner: Vec::with_capacity(len),
231            expected_len: len,
232        })
233    }
234    fn serialize_tuple_struct(
235        self,
236        name: &'static str,
237        len: usize,
238    ) -> Result<Self::SerializeTupleStruct, Self::Error> {
239        Ok(SerializeTupleStruct {
240            expected_len: len,
241            config: self.config,
242            name,
243            values: Vec::with_capacity(len),
244        })
245    }
246    fn serialize_tuple_variant(
247        self,
248        name: &'static str,
249        variant_index: u32,
250        variant: &'static str,
251        len: usize,
252    ) -> Result<Self::SerializeTupleVariant, Self::Error> {
253        Ok(SerializeTupleVariant {
254            expected_len: len,
255            config: self.config,
256            variant: Variant {
257                name,
258                variant_index,
259                variant,
260            },
261            values: Vec::with_capacity(len),
262        })
263    }
264    fn serialize_map(self, len: Option<usize>) -> Result<Self::SerializeMap, Self::Error> {
265        let capacity = len.unwrap_or_default();
266        Ok(SerializeMap {
267            config: self.config,
268            expected_len: len,
269            keys: Vec::with_capacity(capacity),
270            values: Vec::with_capacity(capacity),
271        })
272    }
273    fn serialize_struct(
274        self,
275        name: &'static str,
276        len: usize,
277    ) -> Result<Self::SerializeStruct, Self::Error> {
278        Ok(SerializeStruct {
279            expected_len: len,
280            config: self.config,
281            name,
282            fields: Vec::with_capacity(len),
283        })
284    }
285    fn serialize_struct_variant(
286        self,
287        name: &'static str,
288        variant_index: u32,
289        variant: &'static str,
290        len: usize,
291    ) -> Result<Self::SerializeStructVariant, Self::Error> {
292        Ok(SerializeStructVariant {
293            config: self.config,
294            variant: Variant {
295                name,
296                variant_index,
297                variant,
298            },
299            fields: Vec::with_capacity(len),
300            expected_len: len,
301        })
302    }
303}
304
305fn check_length<E>(
306    what: &str,
307    config: &Config<E>,
308    expected: usize,
309    pushing: &mut Vec<Save<'static, E::SaveError>>,
310) -> Result<(), Error>
311where
312    E: ErrorDiscipline,
313{
314    if config.protocol_errors {
315        let actual = pushing.len();
316        if expected != actual {
317            let e = Error {
318                msg: format!(
319                    "protocol error: expected a {} of length {}, got {}",
320                    what, expected, actual
321                ),
322                protocol: true,
323            };
324            pushing.push(E::handle(Err(e))?)
325        }
326    }
327    Ok(())
328}
329
330pub struct SerializeSeq<E: ErrorDiscipline> {
331    config: Config<E>,
332    expected_len: Option<usize>,
333    inner: Vec<Save<'static, E::SaveError>>,
334}
335impl<E> serde::ser::SerializeSeq for SerializeSeq<E>
336where
337    E: ErrorDiscipline,
338{
339    type Ok = Save<'static, E::SaveError>;
340    type Error = Error;
341    fn serialize_element<T: ?Sized + serde::Serialize>(
342        &mut self,
343        value: &T,
344    ) -> Result<(), Self::Error> {
345        self.inner.push(E::handle(value.serialize(Serializer {
346            config: self.config,
347        }))?);
348        Ok(())
349    }
350    fn end(mut self) -> Result<Self::Ok, Self::Error> {
351        if let Some(expected_len) = self.expected_len {
352            check_length("sequence", &self.config, expected_len, &mut self.inner)?;
353        }
354        Ok(Save::Seq(self.inner))
355    }
356}
357pub struct SerializeTuple<E: ErrorDiscipline> {
358    expected_len: usize,
359    config: Config<E>,
360    inner: Vec<Save<'static, E::SaveError>>,
361}
362impl<E> serde::ser::SerializeTuple for SerializeTuple<E>
363where
364    E: ErrorDiscipline,
365{
366    type Ok = Save<'static, E::SaveError>;
367    type Error = Error;
368    fn serialize_element<T: ?Sized + serde::Serialize>(
369        &mut self,
370        value: &T,
371    ) -> Result<(), Self::Error> {
372        self.inner.push(E::handle(value.serialize(Serializer {
373            config: self.config,
374        }))?);
375        Ok(())
376    }
377    fn end(mut self) -> Result<Self::Ok, Self::Error> {
378        check_length("tuple", &self.config, self.expected_len, &mut self.inner)?;
379        Ok(Save::Tuple(self.inner))
380    }
381}
382pub struct SerializeTupleStruct<E: ErrorDiscipline> {
383    expected_len: usize,
384    config: Config<E>,
385    name: &'static str,
386    values: Vec<Save<'static, E::SaveError>>,
387}
388impl<E> serde::ser::SerializeTupleStruct for SerializeTupleStruct<E>
389where
390    E: ErrorDiscipline,
391{
392    type Ok = Save<'static, E::SaveError>;
393    type Error = Error;
394    fn serialize_field<T: ?Sized + serde::Serialize>(
395        &mut self,
396        value: &T,
397    ) -> Result<(), Self::Error> {
398        self.values.push(E::handle(value.serialize(Serializer {
399            config: self.config,
400        }))?);
401        Ok(())
402    }
403
404    fn end(mut self) -> Result<Self::Ok, Self::Error> {
405        check_length(
406            "tuple struct",
407            &self.config,
408            self.expected_len,
409            &mut self.values,
410        )?;
411        Ok(Save::TupleStruct {
412            name: self.name,
413            values: self.values,
414        })
415    }
416}
417pub struct SerializeTupleVariant<E: ErrorDiscipline> {
418    expected_len: usize,
419    config: Config<E>,
420    variant: Variant<'static>,
421    values: Vec<Save<'static, E::SaveError>>,
422}
423impl<E> serde::ser::SerializeTupleVariant for SerializeTupleVariant<E>
424where
425    E: ErrorDiscipline,
426{
427    type Ok = Save<'static, E::SaveError>;
428    type Error = Error;
429    fn serialize_field<T: ?Sized + serde::Serialize>(
430        &mut self,
431        value: &T,
432    ) -> Result<(), Self::Error> {
433        self.values.push(E::handle(value.serialize(Serializer {
434            config: self.config,
435        }))?);
436        Ok(())
437    }
438    fn end(mut self) -> Result<Self::Ok, Self::Error> {
439        check_length(
440            "tuple variant",
441            &self.config,
442            self.expected_len,
443            &mut self.values,
444        )?;
445
446        Ok(Save::TupleVariant {
447            variant: self.variant,
448            values: self.values,
449        })
450    }
451}
452pub struct SerializeMap<E: ErrorDiscipline> {
453    expected_len: Option<usize>,
454    config: Config<E>,
455    keys: Vec<Save<'static, E::SaveError>>,
456    values: Vec<Save<'static, E::SaveError>>,
457}
458impl<E> serde::ser::SerializeMap for SerializeMap<E>
459where
460    E: ErrorDiscipline,
461{
462    type Ok = Save<'static, E::SaveError>;
463    type Error = Error;
464    fn serialize_key<T: ?Sized + serde::Serialize>(&mut self, key: &T) -> Result<(), Self::Error> {
465        self.keys.push(E::handle(key.serialize(Serializer {
466            config: self.config,
467        }))?);
468        Ok(())
469    }
470    fn serialize_value<T: ?Sized + serde::Serialize>(
471        &mut self,
472        value: &T,
473    ) -> Result<(), Self::Error> {
474        self.values.push(E::handle(value.serialize(Serializer {
475            config: self.config,
476        }))?);
477        Ok(())
478    }
479    fn end(self) -> Result<Self::Ok, Self::Error> {
480        let n_keys = self.keys.len();
481        let n_values = self.values.len();
482        let mut map = Vec::with_capacity(cmp::max(n_keys, n_values));
483        let mut keys = self.keys.into_iter();
484        let mut values = self.values.into_iter();
485        loop {
486            let e = || Error {
487                msg: format!(
488                    "protocol error: map has {} keys and {} values",
489                    n_keys, n_values
490                ),
491                protocol: true,
492            };
493            match (keys.next(), values.next()) {
494                (None, None) => {
495                    if let Some(expected) = self.expected_len {
496                        if self.config.protocol_errors && expected != map.len() {
497                            let e = || Error {
498                                msg: format!(
499                                    "protocol error: expected a map of length {}, got {}",
500                                    expected,
501                                    map.len()
502                                ),
503                                protocol: true,
504                            };
505                            map.push((E::handle(Err(e()))?, E::handle(Err(e()))?))
506                        }
507                    }
508                    return Ok(Save::Map(map));
509                }
510                (Some(key), Some(value)) => map.push((key, value)),
511                (None, Some(value)) => map.push((E::handle(Err(e()))?, value)),
512                (Some(key), None) => map.push((key, E::handle(Err(e()))?)),
513            }
514        }
515    }
516}
517
518fn check<E>(
519    what: &str,
520    config: &Config<E>,
521    expected_len: usize,
522    fields: &mut Vec<(&'static str, Option<Save<'static, E::SaveError>>)>,
523) -> Result<(), Error>
524where
525    E: ErrorDiscipline,
526{
527    if config.protocol_errors {
528        let mut seen = BTreeSet::new();
529        let mut dups = Vec::new();
530        for name in fields.iter().map(|(it, _)| it) {
531            let new = seen.insert(*name);
532            if !new {
533                dups.push(*name)
534            }
535        }
536        if !dups.is_empty() {
537            let e = Error {
538                msg: format!(
539                    "protocol error: {} has duplicate field names: {}",
540                    what,
541                    dups.join(", ")
542                ),
543                protocol: true,
544            };
545            fields.push(("!error", Some(E::handle(Err(e))?)))
546        }
547
548        let actual = fields.len();
549        if expected_len != actual {
550            let e = Error {
551                msg: format!(
552                    "protocol error: expected a {} of length {}, got {}",
553                    what, expected_len, actual
554                ),
555                protocol: true,
556            };
557            fields.push(("!error", Some(E::handle(Err(e))?)))
558        }
559    }
560    Ok(())
561}
562
563pub struct SerializeStruct<E: ErrorDiscipline> {
564    expected_len: usize,
565    config: Config<E>,
566    name: &'static str,
567    fields: Vec<(&'static str, Option<Save<'static, E::SaveError>>)>,
568}
569impl<E> serde::ser::SerializeStruct for SerializeStruct<E>
570where
571    E: ErrorDiscipline,
572{
573    type Ok = Save<'static, E::SaveError>;
574    type Error = Error;
575    fn serialize_field<T: ?Sized + serde::Serialize>(
576        &mut self,
577        key: &'static str,
578        value: &T,
579    ) -> Result<(), Self::Error> {
580        self.fields.push((
581            key,
582            Some(E::handle(value.serialize(Serializer {
583                config: self.config,
584            }))?),
585        ));
586        Ok(())
587    }
588    fn end(mut self) -> Result<Self::Ok, Self::Error> {
589        check("struct", &self.config, self.expected_len, &mut self.fields)?;
590        Ok(Save::Struct {
591            name: self.name,
592            fields: self.fields,
593        })
594    }
595    fn skip_field(&mut self, key: &'static str) -> Result<(), Self::Error> {
596        self.fields.push((key, None));
597        Ok(())
598    }
599}
600pub struct SerializeStructVariant<E: ErrorDiscipline> {
601    expected_len: usize,
602    config: Config<E>,
603    variant: Variant<'static>,
604    fields: Vec<(&'static str, Option<Save<'static, E::SaveError>>)>,
605}
606impl<E> serde::ser::SerializeStructVariant for SerializeStructVariant<E>
607where
608    E: ErrorDiscipline,
609{
610    type Ok = Save<'static, E::SaveError>;
611    type Error = Error;
612    fn serialize_field<T: ?Sized + serde::Serialize>(
613        &mut self,
614        key: &'static str,
615        value: &T,
616    ) -> Result<(), Self::Error> {
617        self.fields.push((
618            key,
619            Some(E::handle(value.serialize(Serializer {
620                config: self.config,
621            }))?),
622        ));
623        Ok(())
624    }
625    fn end(mut self) -> Result<Self::Ok, Self::Error> {
626        check("struct", &self.config, self.expected_len, &mut self.fields)?;
627
628        Ok(Save::StructVariant {
629            variant: self.variant,
630            fields: self.fields,
631        })
632    }
633    fn skip_field(&mut self, key: &'static str) -> Result<(), Self::Error> {
634        self.fields.push((key, None));
635        Ok(())
636    }
637}