1use crate::{Error, Save, Variant};
2use core::{cmp, convert::Infallible, fmt, marker::PhantomData};
3use std::collections::BTreeSet;
4
5mod sealed {
6 pub trait Sealed {}
7 impl Sealed for super::ShortCircuit {}
8 impl Sealed for super::Persist {}
9}
10
11pub trait ErrorDiscipline: sealed::Sealed {
12 type SaveError;
13 fn handle(res: Result<Save<Self::SaveError>, Error>) -> Result<Save<Self::SaveError>, Error>;
14}
15
/// [`ErrorDiscipline`] that aborts at the first error: a failure while
/// serializing a nested value is propagated to the caller unchanged.
///
/// Uninhabited; it is only ever used as a type parameter.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ShortCircuit {}

/// [`ErrorDiscipline`] that records failures of nested values inside the
/// output tree (as `Save::Error`) instead of aborting.
///
/// Uninhabited; it is only ever used as a type parameter.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Persist {}
18
19impl ErrorDiscipline for ShortCircuit {
20 type SaveError = Infallible;
21 fn handle(res: Result<Save<Self::SaveError>, Error>) -> Result<Save<Self::SaveError>, Error> {
22 res
23 }
24}
25
26impl ErrorDiscipline for Persist {
27 type SaveError = Error;
28 fn handle(res: Result<Save<Self::SaveError>, Error>) -> Result<Save<Self::SaveError>, Error> {
29 Ok(res.unwrap_or_else(Save::Error))
30 }
31}
32
33pub struct Serializer<ErrorDiscipline = ShortCircuit> {
37 config: Config<ErrorDiscipline>,
38}
39
40impl Serializer<ShortCircuit> {
41 pub fn new() -> Self {
45 Self {
46 config: Config {
47 is_human_readable: true,
48 protocol_errors: false,
49 _error_discipline: PhantomData,
50 },
51 }
52 }
53}
54
55impl<E> Serializer<E> {
56 pub fn human_readable(mut self, is_human_readable: bool) -> Self {
58 self.config.is_human_readable = is_human_readable;
59 self
60 }
61 pub fn check_for_protocol_errors(mut self, check: bool) -> Self {
64 self.config.protocol_errors = check;
65 self
66 }
67 pub fn save_errors(self) -> Serializer<Persist> {
75 let Self {
76 config:
77 Config {
78 is_human_readable,
79 protocol_errors,
80 _error_discipline,
81 },
82 } = self;
83 Serializer {
84 config: Config {
85 is_human_readable,
86 protocol_errors,
87 _error_discipline: PhantomData,
88 },
89 }
90 }
91}
92
93impl Default for Serializer {
94 fn default() -> Self {
96 Self::new()
97 }
98}
99
100struct Config<E = ShortCircuit> {
101 is_human_readable: bool,
102 protocol_errors: bool,
103 _error_discipline: PhantomData<fn() -> E>,
104}
105
106impl<E> Clone for Config<E> {
107 fn clone(&self) -> Self {
108 *self
109 }
110}
111impl<E> Copy for Config<E> {}
112
/// Generates `serde::Serializer` methods that wrap their argument directly in
/// the named `Save` variant. For example, `serialize_u8(u8) -> U8` expands to
/// a `serialize_u8` method returning `Ok(Save::U8(value))`.
macro_rules! simple {
    ($($fn_name:ident($arg_ty:ty) -> $ctor:ident);* $(;)?) => {
        $(
            fn $fn_name(self, value: $arg_ty) -> Result<Self::Ok, Self::Error> {
                Ok(Save::$ctor(value))
            }
        )*
    };
}
122
123impl<E> serde::Serializer for Serializer<E>
124where
125 E: ErrorDiscipline,
126{
127 type Ok = Save<'static, E::SaveError>;
128 type Error = Error;
129 type SerializeSeq = SerializeSeq<E>;
130 type SerializeTuple = SerializeTuple<E>;
131 type SerializeTupleStruct = SerializeTupleStruct<E>;
132 type SerializeTupleVariant = SerializeTupleVariant<E>;
133 type SerializeMap = SerializeMap<E>;
134 type SerializeStruct = SerializeStruct<E>;
135 type SerializeStructVariant = SerializeStructVariant<E>;
136
137 fn is_human_readable(&self) -> bool {
138 self.config.is_human_readable
139 }
140
141 simple! {
142 serialize_bool(bool) -> Bool;
143 serialize_i8(i8) -> I8;
144 serialize_i16(i16) -> I16;
145 serialize_i32(i32) -> I32;
146 serialize_i64(i64) -> I64;
147 serialize_u8(u8) -> U8;
148 serialize_u16(u16) -> U16;
149 serialize_u32(u32) -> U32;
150 serialize_u64(u64) -> U64;
151 serialize_f32(f32) -> F32;
152 serialize_f64(f64) -> F64;
153 serialize_char(char) -> Char;
154 }
155
156 fn serialize_str(self, v: &str) -> Result<Self::Ok, Self::Error> {
157 Ok(Save::String(v.into()))
158 }
159 fn collect_str<T: ?Sized + fmt::Display>(self, value: &T) -> Result<Self::Ok, Self::Error> {
160 Ok(Save::String(value.to_string()))
161 }
162 fn serialize_bytes(self, v: &[u8]) -> Result<Self::Ok, Self::Error> {
163 Ok(Save::ByteArray(v.into()))
164 }
165 fn serialize_none(self) -> Result<Self::Ok, Self::Error> {
166 Ok(Save::Option(None))
167 }
168 fn serialize_some<T: ?Sized + serde::Serialize>(
169 self,
170 value: &T,
171 ) -> Result<Self::Ok, Self::Error> {
172 Ok(Save::Option(Some(Box::new(E::handle(
173 value.serialize(self),
174 )?))))
175 }
176 fn serialize_unit(self) -> Result<Self::Ok, Self::Error> {
177 Ok(Save::Unit)
178 }
179 fn serialize_unit_struct(self, name: &'static str) -> Result<Self::Ok, Self::Error> {
180 Ok(Save::UnitStruct(name))
181 }
182 fn serialize_unit_variant(
183 self,
184 name: &'static str,
185 variant_index: u32,
186 variant: &'static str,
187 ) -> Result<Self::Ok, Self::Error> {
188 Ok(Save::UnitVariant(Variant {
189 name,
190 variant_index,
191 variant,
192 }))
193 }
194 fn serialize_newtype_struct<T: ?Sized + serde::Serialize>(
195 self,
196 name: &'static str,
197 value: &T,
198 ) -> Result<Self::Ok, Self::Error> {
199 Ok(Save::NewTypeStruct {
200 name,
201 value: Box::new(E::handle(value.serialize(self))?),
202 })
203 }
204 fn serialize_newtype_variant<T: ?Sized + serde::Serialize>(
205 self,
206 name: &'static str,
207 variant_index: u32,
208 variant: &'static str,
209 value: &T,
210 ) -> Result<Self::Ok, Self::Error> {
211 Ok(Save::NewTypeVariant {
212 variant: Variant {
213 name,
214 variant_index,
215 variant,
216 },
217 value: Box::new(E::handle(value.serialize(self))?),
218 })
219 }
220 fn serialize_seq(self, len: Option<usize>) -> Result<Self::SerializeSeq, Self::Error> {
221 Ok(SerializeSeq {
222 config: self.config,
223 inner: Vec::with_capacity(len.unwrap_or_default()),
224 expected_len: len,
225 })
226 }
227 fn serialize_tuple(self, len: usize) -> Result<Self::SerializeTuple, Self::Error> {
228 Ok(SerializeTuple {
229 config: self.config,
230 inner: Vec::with_capacity(len),
231 expected_len: len,
232 })
233 }
234 fn serialize_tuple_struct(
235 self,
236 name: &'static str,
237 len: usize,
238 ) -> Result<Self::SerializeTupleStruct, Self::Error> {
239 Ok(SerializeTupleStruct {
240 expected_len: len,
241 config: self.config,
242 name,
243 values: Vec::with_capacity(len),
244 })
245 }
246 fn serialize_tuple_variant(
247 self,
248 name: &'static str,
249 variant_index: u32,
250 variant: &'static str,
251 len: usize,
252 ) -> Result<Self::SerializeTupleVariant, Self::Error> {
253 Ok(SerializeTupleVariant {
254 expected_len: len,
255 config: self.config,
256 variant: Variant {
257 name,
258 variant_index,
259 variant,
260 },
261 values: Vec::with_capacity(len),
262 })
263 }
264 fn serialize_map(self, len: Option<usize>) -> Result<Self::SerializeMap, Self::Error> {
265 let capacity = len.unwrap_or_default();
266 Ok(SerializeMap {
267 config: self.config,
268 expected_len: len,
269 keys: Vec::with_capacity(capacity),
270 values: Vec::with_capacity(capacity),
271 })
272 }
273 fn serialize_struct(
274 self,
275 name: &'static str,
276 len: usize,
277 ) -> Result<Self::SerializeStruct, Self::Error> {
278 Ok(SerializeStruct {
279 expected_len: len,
280 config: self.config,
281 name,
282 fields: Vec::with_capacity(len),
283 })
284 }
285 fn serialize_struct_variant(
286 self,
287 name: &'static str,
288 variant_index: u32,
289 variant: &'static str,
290 len: usize,
291 ) -> Result<Self::SerializeStructVariant, Self::Error> {
292 Ok(SerializeStructVariant {
293 config: self.config,
294 variant: Variant {
295 name,
296 variant_index,
297 variant,
298 },
299 fields: Vec::with_capacity(len),
300 expected_len: len,
301 })
302 }
303}
304
305fn check_length<E>(
306 what: &str,
307 config: &Config<E>,
308 expected: usize,
309 pushing: &mut Vec<Save<'static, E::SaveError>>,
310) -> Result<(), Error>
311where
312 E: ErrorDiscipline,
313{
314 if config.protocol_errors {
315 let actual = pushing.len();
316 if expected != actual {
317 let e = Error {
318 msg: format!(
319 "protocol error: expected a {} of length {}, got {}",
320 what, expected, actual
321 ),
322 protocol: true,
323 };
324 pushing.push(E::handle(Err(e))?)
325 }
326 }
327 Ok(())
328}
329
330pub struct SerializeSeq<E: ErrorDiscipline> {
331 config: Config<E>,
332 expected_len: Option<usize>,
333 inner: Vec<Save<'static, E::SaveError>>,
334}
335impl<E> serde::ser::SerializeSeq for SerializeSeq<E>
336where
337 E: ErrorDiscipline,
338{
339 type Ok = Save<'static, E::SaveError>;
340 type Error = Error;
341 fn serialize_element<T: ?Sized + serde::Serialize>(
342 &mut self,
343 value: &T,
344 ) -> Result<(), Self::Error> {
345 self.inner.push(E::handle(value.serialize(Serializer {
346 config: self.config,
347 }))?);
348 Ok(())
349 }
350 fn end(mut self) -> Result<Self::Ok, Self::Error> {
351 if let Some(expected_len) = self.expected_len {
352 check_length("sequence", &self.config, expected_len, &mut self.inner)?;
353 }
354 Ok(Save::Seq(self.inner))
355 }
356}
357pub struct SerializeTuple<E: ErrorDiscipline> {
358 expected_len: usize,
359 config: Config<E>,
360 inner: Vec<Save<'static, E::SaveError>>,
361}
362impl<E> serde::ser::SerializeTuple for SerializeTuple<E>
363where
364 E: ErrorDiscipline,
365{
366 type Ok = Save<'static, E::SaveError>;
367 type Error = Error;
368 fn serialize_element<T: ?Sized + serde::Serialize>(
369 &mut self,
370 value: &T,
371 ) -> Result<(), Self::Error> {
372 self.inner.push(E::handle(value.serialize(Serializer {
373 config: self.config,
374 }))?);
375 Ok(())
376 }
377 fn end(mut self) -> Result<Self::Ok, Self::Error> {
378 check_length("tuple", &self.config, self.expected_len, &mut self.inner)?;
379 Ok(Save::Tuple(self.inner))
380 }
381}
382pub struct SerializeTupleStruct<E: ErrorDiscipline> {
383 expected_len: usize,
384 config: Config<E>,
385 name: &'static str,
386 values: Vec<Save<'static, E::SaveError>>,
387}
388impl<E> serde::ser::SerializeTupleStruct for SerializeTupleStruct<E>
389where
390 E: ErrorDiscipline,
391{
392 type Ok = Save<'static, E::SaveError>;
393 type Error = Error;
394 fn serialize_field<T: ?Sized + serde::Serialize>(
395 &mut self,
396 value: &T,
397 ) -> Result<(), Self::Error> {
398 self.values.push(E::handle(value.serialize(Serializer {
399 config: self.config,
400 }))?);
401 Ok(())
402 }
403
404 fn end(mut self) -> Result<Self::Ok, Self::Error> {
405 check_length(
406 "tuple struct",
407 &self.config,
408 self.expected_len,
409 &mut self.values,
410 )?;
411 Ok(Save::TupleStruct {
412 name: self.name,
413 values: self.values,
414 })
415 }
416}
417pub struct SerializeTupleVariant<E: ErrorDiscipline> {
418 expected_len: usize,
419 config: Config<E>,
420 variant: Variant<'static>,
421 values: Vec<Save<'static, E::SaveError>>,
422}
423impl<E> serde::ser::SerializeTupleVariant for SerializeTupleVariant<E>
424where
425 E: ErrorDiscipline,
426{
427 type Ok = Save<'static, E::SaveError>;
428 type Error = Error;
429 fn serialize_field<T: ?Sized + serde::Serialize>(
430 &mut self,
431 value: &T,
432 ) -> Result<(), Self::Error> {
433 self.values.push(E::handle(value.serialize(Serializer {
434 config: self.config,
435 }))?);
436 Ok(())
437 }
438 fn end(mut self) -> Result<Self::Ok, Self::Error> {
439 check_length(
440 "tuple variant",
441 &self.config,
442 self.expected_len,
443 &mut self.values,
444 )?;
445
446 Ok(Save::TupleVariant {
447 variant: self.variant,
448 values: self.values,
449 })
450 }
451}
452pub struct SerializeMap<E: ErrorDiscipline> {
453 expected_len: Option<usize>,
454 config: Config<E>,
455 keys: Vec<Save<'static, E::SaveError>>,
456 values: Vec<Save<'static, E::SaveError>>,
457}
458impl<E> serde::ser::SerializeMap for SerializeMap<E>
459where
460 E: ErrorDiscipline,
461{
462 type Ok = Save<'static, E::SaveError>;
463 type Error = Error;
464 fn serialize_key<T: ?Sized + serde::Serialize>(&mut self, key: &T) -> Result<(), Self::Error> {
465 self.keys.push(E::handle(key.serialize(Serializer {
466 config: self.config,
467 }))?);
468 Ok(())
469 }
470 fn serialize_value<T: ?Sized + serde::Serialize>(
471 &mut self,
472 value: &T,
473 ) -> Result<(), Self::Error> {
474 self.values.push(E::handle(value.serialize(Serializer {
475 config: self.config,
476 }))?);
477 Ok(())
478 }
479 fn end(self) -> Result<Self::Ok, Self::Error> {
480 let n_keys = self.keys.len();
481 let n_values = self.values.len();
482 let mut map = Vec::with_capacity(cmp::max(n_keys, n_values));
483 let mut keys = self.keys.into_iter();
484 let mut values = self.values.into_iter();
485 loop {
486 let e = || Error {
487 msg: format!(
488 "protocol error: map has {} keys and {} values",
489 n_keys, n_values
490 ),
491 protocol: true,
492 };
493 match (keys.next(), values.next()) {
494 (None, None) => {
495 if let Some(expected) = self.expected_len {
496 if self.config.protocol_errors && expected != map.len() {
497 let e = || Error {
498 msg: format!(
499 "protocol error: expected a map of length {}, got {}",
500 expected,
501 map.len()
502 ),
503 protocol: true,
504 };
505 map.push((E::handle(Err(e()))?, E::handle(Err(e()))?))
506 }
507 }
508 return Ok(Save::Map(map));
509 }
510 (Some(key), Some(value)) => map.push((key, value)),
511 (None, Some(value)) => map.push((E::handle(Err(e()))?, value)),
512 (Some(key), None) => map.push((key, E::handle(Err(e()))?)),
513 }
514 }
515 }
516}
517
518fn check<E>(
519 what: &str,
520 config: &Config<E>,
521 expected_len: usize,
522 fields: &mut Vec<(&'static str, Option<Save<'static, E::SaveError>>)>,
523) -> Result<(), Error>
524where
525 E: ErrorDiscipline,
526{
527 if config.protocol_errors {
528 let mut seen = BTreeSet::new();
529 let mut dups = Vec::new();
530 for name in fields.iter().map(|(it, _)| it) {
531 let new = seen.insert(*name);
532 if !new {
533 dups.push(*name)
534 }
535 }
536 if !dups.is_empty() {
537 let e = Error {
538 msg: format!(
539 "protocol error: {} has duplicate field names: {}",
540 what,
541 dups.join(", ")
542 ),
543 protocol: true,
544 };
545 fields.push(("!error", Some(E::handle(Err(e))?)))
546 }
547
548 let actual = fields.len();
549 if expected_len != actual {
550 let e = Error {
551 msg: format!(
552 "protocol error: expected a {} of length {}, got {}",
553 what, expected_len, actual
554 ),
555 protocol: true,
556 };
557 fields.push(("!error", Some(E::handle(Err(e))?)))
558 }
559 }
560 Ok(())
561}
562
563pub struct SerializeStruct<E: ErrorDiscipline> {
564 expected_len: usize,
565 config: Config<E>,
566 name: &'static str,
567 fields: Vec<(&'static str, Option<Save<'static, E::SaveError>>)>,
568}
569impl<E> serde::ser::SerializeStruct for SerializeStruct<E>
570where
571 E: ErrorDiscipline,
572{
573 type Ok = Save<'static, E::SaveError>;
574 type Error = Error;
575 fn serialize_field<T: ?Sized + serde::Serialize>(
576 &mut self,
577 key: &'static str,
578 value: &T,
579 ) -> Result<(), Self::Error> {
580 self.fields.push((
581 key,
582 Some(E::handle(value.serialize(Serializer {
583 config: self.config,
584 }))?),
585 ));
586 Ok(())
587 }
588 fn end(mut self) -> Result<Self::Ok, Self::Error> {
589 check("struct", &self.config, self.expected_len, &mut self.fields)?;
590 Ok(Save::Struct {
591 name: self.name,
592 fields: self.fields,
593 })
594 }
595 fn skip_field(&mut self, key: &'static str) -> Result<(), Self::Error> {
596 self.fields.push((key, None));
597 Ok(())
598 }
599}
600pub struct SerializeStructVariant<E: ErrorDiscipline> {
601 expected_len: usize,
602 config: Config<E>,
603 variant: Variant<'static>,
604 fields: Vec<(&'static str, Option<Save<'static, E::SaveError>>)>,
605}
606impl<E> serde::ser::SerializeStructVariant for SerializeStructVariant<E>
607where
608 E: ErrorDiscipline,
609{
610 type Ok = Save<'static, E::SaveError>;
611 type Error = Error;
612 fn serialize_field<T: ?Sized + serde::Serialize>(
613 &mut self,
614 key: &'static str,
615 value: &T,
616 ) -> Result<(), Self::Error> {
617 self.fields.push((
618 key,
619 Some(E::handle(value.serialize(Serializer {
620 config: self.config,
621 }))?),
622 ));
623 Ok(())
624 }
625 fn end(mut self) -> Result<Self::Ok, Self::Error> {
626 check("struct", &self.config, self.expected_len, &mut self.fields)?;
627
628 Ok(Save::StructVariant {
629 variant: self.variant,
630 fields: self.fields,
631 })
632 }
633 fn skip_field(&mut self, key: &'static str) -> Result<(), Self::Error> {
634 self.fields.push((key, None));
635 Ok(())
636 }
637}