Skip to main content

serde_reflection/
trace.rs

1// Copyright (c) Facebook, Inc. and its affiliates
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use crate::{
5    de::Deserializer,
6    error::{Error, Result},
7    format::*,
8    ser::Serializer,
9    value::Value,
10};
11use erased_discriminant::Discriminant;
12use once_cell::sync::Lazy;
13use serde::{de::DeserializeSeed, Deserialize, Serialize};
14use std::any::TypeId;
15use std::collections::BTreeMap;
16
17/// A map of container formats.
18pub type Registry = BTreeMap<String, ContainerFormat>;
19
20/// Structure to drive the tracing of Serde serialization and deserialization.
21/// This typically aims at computing a `Registry`.
22#[derive(Debug)]
23pub struct Tracer {
24    /// Hold configuration options.
25    pub(crate) config: TracerConfig,
26
27    /// Formats of the named containers discovered so far, while tracing
28    /// serialization and/or deserialization.
29    pub(crate) registry: Registry,
30
31    /// Enums that have detected to be yet incomplete (i.e. missing variants)
32    /// while tracing deserialization.
33    pub(crate) incomplete_enums: BTreeMap<String, IncompleteEnumReason>,
34
35    /// Discriminant associated with each variant of each enum.
36    pub(crate) discriminants: BTreeMap<(TypeId, VariantId<'static>), Discriminant>,
37}
38
/// Kind of enum variants that still remain untraced.
///
/// Derives `PartialEq`/`Eq`/`Hash` so that callers can compare reasons
/// (e.g. the value returned by `Tracer::check_incomplete_enum`) or use them as keys.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum IncompleteEnumReason {
    /// There are variant names that have not yet been traced.
    NamedVariantsRemaining,
    /// There are variant numbers that have not yet been traced.
    IndexedVariantsRemaining,
}
47
/// Identifies an enum variant either by its numeric index or by its name.
///
/// With the derived ordering, every `Index(_)` sorts before every `Name(_)`.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)]
pub(crate) enum VariantId<'a> {
    /// Variant identified by its index.
    Index(u32),
    /// Variant identified by its name.
    Name(&'a str),
}
53
54/// User inputs, aka "samples", recorded during serialization.
55/// This will help passing user-defined checks during deserialization.
56#[derive(Debug, Default)]
57pub struct Samples {
58    pub(crate) values: BTreeMap<&'static str, Value>,
59}
60
61impl Samples {
62    /// Create a new structure to hold value samples.
63    pub fn new() -> Self {
64        Self::default()
65    }
66
67    /// Obtain a (serialized) sample.
68    pub fn value(&self, name: &'static str) -> Option<&Value> {
69        self.values.get(name)
70    }
71}
72
/// Configuration object to create a tracer.
///
/// Built with `TracerConfig::default()` followed by the builder-style setters.
/// `Clone` is derived so that a single configuration can seed several tracers:
/// `Tracer::new` takes the configuration by value.
#[derive(Debug, Clone)]
pub struct TracerConfig {
    pub(crate) is_human_readable: bool,
    pub(crate) record_samples_for_newtype_structs: bool,
    pub(crate) record_samples_for_tuple_structs: bool,
    pub(crate) record_samples_for_structs: bool,
    // Per-primitive default values, each settable through the builder method of the same name.
    pub(crate) default_bool_value: bool,
    pub(crate) default_u8_value: u8,
    pub(crate) default_u16_value: u16,
    pub(crate) default_u32_value: u32,
    pub(crate) default_u64_value: u64,
    pub(crate) default_u128_value: u128,
    pub(crate) default_i8_value: i8,
    pub(crate) default_i16_value: i16,
    pub(crate) default_i32_value: i32,
    pub(crate) default_i64_value: i64,
    pub(crate) default_i128_value: i128,
    pub(crate) default_f32_value: f32,
    pub(crate) default_f64_value: f64,
    pub(crate) default_char_value: char,
    pub(crate) default_borrowed_str_value: &'static str,
    pub(crate) default_string_value: String,
    pub(crate) default_borrowed_bytes_value: &'static [u8],
    pub(crate) default_byte_buf_value: Vec<u8>,
}
99
100impl Default for TracerConfig {
101    /// Create a new structure to hold value samples.
102    fn default() -> Self {
103        Self {
104            is_human_readable: false,
105            record_samples_for_newtype_structs: true,
106            record_samples_for_tuple_structs: false,
107            record_samples_for_structs: false,
108            default_bool_value: false,
109            default_u8_value: 0,
110            default_u16_value: 0,
111            default_u32_value: 0,
112            default_u64_value: 0,
113            default_u128_value: 0,
114            default_i8_value: 0,
115            default_i16_value: 0,
116            default_i32_value: 0,
117            default_i64_value: 0,
118            default_i128_value: 0,
119            default_f32_value: 0.0,
120            default_f64_value: 0.0,
121            default_char_value: 'A',
122            default_borrowed_str_value: "",
123            default_string_value: String::new(),
124            default_borrowed_bytes_value: b"",
125            default_byte_buf_value: Vec::new(),
126        }
127    }
128}
129
// Generates a builder-style setter named after the configuration field it overrides.
macro_rules! define_default_value_setter {
    ($field:ident, $ty:ty) => {
        /// Override the default value used for this primitive type.
        /// Consumes the configuration and returns the updated one.
        pub fn $field(self, value: $ty) -> Self {
            Self {
                $field: value,
                ..self
            }
        }
    };
}
139
140impl TracerConfig {
141    /// Whether to trace the human readable encoding of (de)serialization.
142    #[allow(clippy::wrong_self_convention)]
143    pub fn is_human_readable(mut self, value: bool) -> Self {
144        self.is_human_readable = value;
145        self
146    }
147
148    /// Record samples of newtype structs during serialization and inject them during deserialization.
149    pub fn record_samples_for_newtype_structs(mut self, value: bool) -> Self {
150        self.record_samples_for_newtype_structs = value;
151        self
152    }
153
154    /// Record samples of tuple structs during serialization and inject them during deserialization.
155    pub fn record_samples_for_tuple_structs(mut self, value: bool) -> Self {
156        self.record_samples_for_tuple_structs = value;
157        self
158    }
159
160    /// Record samples of (regular) structs during serialization and inject them during deserialization.
161    pub fn record_samples_for_structs(mut self, value: bool) -> Self {
162        self.record_samples_for_structs = value;
163        self
164    }
165
166    define_default_value_setter!(default_bool_value, bool);
167    define_default_value_setter!(default_u8_value, u8);
168    define_default_value_setter!(default_u16_value, u16);
169    define_default_value_setter!(default_u32_value, u32);
170    define_default_value_setter!(default_u64_value, u64);
171    define_default_value_setter!(default_u128_value, u128);
172    define_default_value_setter!(default_i8_value, i8);
173    define_default_value_setter!(default_i16_value, i16);
174    define_default_value_setter!(default_i32_value, i32);
175    define_default_value_setter!(default_i64_value, i64);
176    define_default_value_setter!(default_i128_value, i128);
177    define_default_value_setter!(default_f32_value, f32);
178    define_default_value_setter!(default_f64_value, f64);
179    define_default_value_setter!(default_char_value, char);
180    define_default_value_setter!(default_borrowed_str_value, &'static str);
181    define_default_value_setter!(default_string_value, String);
182    define_default_value_setter!(default_borrowed_bytes_value, &'static [u8]);
183    define_default_value_setter!(default_byte_buf_value, Vec<u8>);
184}
185
186impl Tracer {
187    /// Start tracing deserialization.
188    pub fn new(config: TracerConfig) -> Self {
189        Self {
190            config,
191            registry: BTreeMap::new(),
192            incomplete_enums: BTreeMap::new(),
193            discriminants: BTreeMap::new(),
194        }
195    }
196
197    /// Trace the serialization of a particular value.
198    /// * Nested containers will be added to the tracing registry, indexed by
199    ///   their (non-qualified) name.
200    /// * Sampled Rust values will be inserted into `samples` to benefit future calls
201    ///   to the `trace_type_*` methods.
202    pub fn trace_value<T>(&mut self, samples: &mut Samples, value: &T) -> Result<(Format, Value)>
203    where
204        T: ?Sized + Serialize,
205    {
206        let serializer = Serializer::new(self, samples);
207        let (mut format, sample) = value.serialize(serializer)?;
208        format.reduce();
209        Ok((format, sample))
210    }
211
212    /// Trace a single deserialization of a particular type.
213    /// * Nested containers will be added to the tracing registry, indexed by
214    ///   their (non-qualified) name.
215    /// * As a byproduct of deserialization, we also return a value of type `T`.
216    /// * Tracing deserialization of a type may fail if this type or some dependencies
217    ///   have implemented a custom deserializer that validates data. The solution is
218    ///   to make sure that `samples` holds enough sampled Rust values to cover all the
219    ///   custom types.
220    pub fn trace_type_once<'de, T>(&mut self, samples: &'de Samples) -> Result<(Format, T)>
221    where
222        T: Deserialize<'de>,
223    {
224        let mut format = Format::unknown();
225        let deserializer = Deserializer::new(self, samples, &mut format);
226        let value = T::deserialize(deserializer)?;
227        format.reduce();
228        Ok((format, value))
229    }
230
231    /// Same as `trace_type_once` for seeded deserialization.
232    pub fn trace_type_once_with_seed<'de, S>(
233        &mut self,
234        samples: &'de Samples,
235        seed: S,
236    ) -> Result<(Format, S::Value)>
237    where
238        S: DeserializeSeed<'de>,
239    {
240        let mut format = Format::unknown();
241        let deserializer = Deserializer::new(self, samples, &mut format);
242        let value = seed.deserialize(deserializer)?;
243        format.reduce();
244        Ok((format, value))
245    }
246
247    /// Read the status of an enum and reset the value.
248    pub fn check_incomplete_enum(&mut self, name: &str) -> Option<IncompleteEnumReason> {
249        self.incomplete_enums.remove(name)
250    }
251
252    /// Same as `trace_type_once` but if `T` is an enum, we repeat the process
253    /// until all variants of `T` are covered.
254    /// We accumulate and return all the sampled values at the end.
255    pub fn trace_type<'de, T>(&mut self, samples: &'de Samples) -> Result<(Format, Vec<T>)>
256    where
257        T: Deserialize<'de>,
258    {
259        let mut values = Vec::new();
260        loop {
261            let (format, value) = self.trace_type_once::<T>(samples)?;
262            values.push(value);
263            if let Format::TypeName(name) = &format {
264                if let Some(reason) = self.check_incomplete_enum(name) {
265                    // Restart the analysis to find more variants of T.
266                    if let IncompleteEnumReason::NamedVariantsRemaining = reason {
267                        values.pop().unwrap();
268                    }
269                    continue;
270                }
271            }
272            return Ok((format, values));
273        }
274    }
275
276    /// Trace a type `T` that is simple enough that no samples of values are needed.
277    /// * If `T` is an enum, the tracing iterates until all variants of `T` are covered.
278    /// * Accumulate and return all the sampled values at the end.
279    ///   This is merely a shortcut for `self.trace_type` with a fixed empty set of samples.
280    pub fn trace_simple_type<'de, T>(&mut self) -> Result<(Format, Vec<T>)>
281    where
282        T: Deserialize<'de>,
283    {
284        static SAMPLES: Lazy<Samples> = Lazy::new(Samples::new);
285        self.trace_type(&SAMPLES)
286    }
287
288    /// Same as `trace_type` for seeded deserialization.
289    pub fn trace_type_with_seed<'de, S>(
290        &mut self,
291        samples: &'de Samples,
292        seed: S,
293    ) -> Result<(Format, Vec<S::Value>)>
294    where
295        S: DeserializeSeed<'de> + Clone,
296    {
297        let mut values = Vec::new();
298        loop {
299            let (format, value) = self.trace_type_once_with_seed(samples, seed.clone())?;
300            values.push(value);
301            if let Format::TypeName(name) = &format {
302                if let Some(reason) = self.check_incomplete_enum(name) {
303                    // Restart the analysis to find more variants of T.
304                    if let IncompleteEnumReason::NamedVariantsRemaining = reason {
305                        values.pop().unwrap();
306                    }
307                    continue;
308                }
309            }
310            return Ok((format, values));
311        }
312    }
313
314    /// Finish tracing and recover a map of normalized formats.
315    /// Returns an error if we detect incompletely traced types.
316    /// This may happen in a few of cases:
317    /// * We traced serialization of user-provided values but we are still missing the content
318    ///   of an option type, the content of a sequence type, the key or the value of a dictionary type.
319    /// * We traced deserialization of an enum type but we detect that some enum variants are still missing.
320    pub fn registry(self) -> Result<Registry> {
321        let mut registry = self.registry;
322        for (name, format) in registry.iter_mut() {
323            format
324                .normalize()
325                .map_err(|_| Error::UnknownFormatInContainer(name.clone()))?;
326        }
327        if self.incomplete_enums.is_empty() {
328            Ok(registry)
329        } else {
330            Err(Error::MissingVariants(
331                self.incomplete_enums.into_keys().collect(),
332            ))
333        }
334    }
335
336    /// Same as registry but always return a value, even if we detected issues.
337    /// This should only be use for debugging.
338    pub fn registry_unchecked(self) -> Registry {
339        let mut registry = self.registry;
340        for format in registry.values_mut() {
341            format.normalize().unwrap_or(());
342        }
343        registry
344    }
345
346    pub(crate) fn record_container(
347        &mut self,
348        samples: &mut Samples,
349        name: &'static str,
350        format: ContainerFormat,
351        value: Value,
352        record_value: bool,
353    ) -> Result<(Format, Value)> {
354        self.registry.entry(name.to_string()).unify(format)?;
355        if record_value {
356            samples.values.insert(name, value.clone());
357        }
358        Ok((Format::TypeName(name.into()), value))
359    }
360
361    pub(crate) fn record_variant(
362        &mut self,
363        samples: &mut Samples,
364        name: &'static str,
365        variant_index: u32,
366        variant_name: &'static str,
367        variant: VariantFormat,
368        variant_value: Value,
369    ) -> Result<(Format, Value)> {
370        let mut variants = BTreeMap::new();
371        variants.insert(
372            variant_index,
373            Named {
374                name: variant_name.into(),
375                value: variant,
376            },
377        );
378        let format = ContainerFormat::Enum(variants);
379        let value = Value::Variant(variant_index, Box::new(variant_value));
380        self.record_container(samples, name, format, value, false)
381    }
382
383    pub(crate) fn get_sample<'de, 'a>(
384        &'a self,
385        samples: &'de Samples,
386        name: &'static str,
387    ) -> Option<(&'a ContainerFormat, &'de Value)> {
388        match samples.value(name) {
389            Some(value) => {
390                let format = self
391                    .registry
392                    .get(name)
393                    .expect("recorded containers should have a format already");
394                Some((format, value))
395            }
396            None => None,
397        }
398    }
399}