Skip to main content

tor_config/
flatten.rs

1//! Similar to `#[serde(flatten)]` but works with [`serde_ignored`]
2//!
3//! Our approach to deserialize a [`Flatten`] is as follows:
4//!
5//!  * We tell the input data format (underlying deserializer) that we want a map.
6//!  * In our visitor, we visit each key in the map in order
7//!  * For each key, we consult `Flattenable::has_field` to find out which child it's in
8//!    (fields in T shadow fields in U, as with serde),
9//!    and store the key and the value in the appropriate [`Portion`].
10//!    (We must store the value as a [`serde_value::Value`]
11//!    since we don't know what type it should be,
12//!    and can't know until we are ready to enter T and U's [`Deserialize`] impls.)
13//!  * If it's in neither T nor U, we explicitly ignore the value
14//!  * When we've processed all the fields, we call the actual deserialisers for T and U:
15//!    we take on the role of the data format, giving each of T and U a map.
16//!
17//! From the point of view of T and U, we each offer them a subset of the fields,
18//! having already rendered the keys to strings and the values to `Value`.
19//!
20//! From the point of view of the data format (which might be a `serde_ignored` proxy)
21//! we consume the union of the fields, and ignore the rest.
22//!
23//! ### Rationale and alternatives
24//!
25//! The key difficulty is this:
26//! we want to call [`Deserializer::deserialize_ignored_any`]
27//! on our input data format for precisely the fields which neither T nor U want.
28//! We must achieve this somehow using information from T or U.
29//! If we tried to use only the [`Deserialize`] impls,
30//! the only way to detect this is to call their `deserialize` methods
31//! and watch to see if they in turn call `deserialize_ignored_any`.
32//! But we need to be asking each of T and U this question for each field:
33//! the shape of [`MapAccess`] puts the data structure in charge of sequencing.
34//! So we would need to somehow suspend `T`'s deserialisation,
//! and call `U`'s, and then suspend `U`'s, and go back to `T`.
36//!
37//! Other possibilities that seemed worse:
38//!
39//!  * Use threads.
40//!    We could spawn a thread for each of `T` and `U`,
41//!    allowing us to run them in parallel and control their execution flow.
42//!
43//!  * Use coroutines eg. [corosensei](https://lib.rs/crates/corosensei)
44//!    (by Amanieu, author of hashbrown etc.)
45//!
46//!  * Instead of suspending and restarting `T` and `U`'s deserialisation,
47//!    discard the partially-deserialised `T` and `U` and restart them each time
48//!    (with cloned copies of the `Value`s).  This is O(n^2) and involves much boxing.
49//!
50//! # References
51//!
52//!  * Tickets against `serde-ignored`:
53//!    <https://github.com/dtolnay/serde-ignored/issues/17>
54//!    <https://github.com/dtolnay/serde-ignored/issues/10>
55//!
56//!  * Workaround with `HashMap` that doesn't quite work right:
57//!    <https://github.com/dtolnay/serde-ignored/issues/10#issuecomment-1044058310>
58//!    <https://github.com/serde-rs/serde/issues/2176>
59//!
60//!  * Discussion in Tor Project gitlab re Arti configuration:
61//!    <https://gitlab.torproject.org/tpo/core/arti/-/merge_requests/1599#note_2944510>
62
63use std::collections::VecDeque;
64use std::fmt::{self, Display};
65use std::marker::PhantomData;
66use std::mem;
67
68use derive_deftly::{Deftly, define_derive_deftly, derive_deftly_adhoc};
69use paste::paste;
70use serde::de::{self, DeserializeSeed, Deserializer, Error as _, IgnoredAny, MapAccess, Visitor};
71use serde::{Deserialize, Serialize, Serializer};
72use serde_value::Value;
73use thiserror::Error;
74
75// Must come first so we can refer to it in docs
define_derive_deftly! {
    /// Derives [`Flattenable`] for a struct
    ///
    /// # Generated items
    ///
    /// The expansion consists of an `impl Flattenable` whose `has_field`
    /// compares its argument against the field names reported by
    /// `flattenable_extract_fields`, plus a `#[test]` function
    /// which calls that `has_field` once.
    ///
    /// # Limitations
    ///
    /// Some serde attributes might not be supported.
    /// For example, ones which make the type no longer deserialize as a named fields struct.
    /// This will be detected by a macro-generated always-failing test case.
    ///
    /// Most serde attributes (eg field renaming and ignoring) will be fine.
    ///
    /// # Example
    ///
    /// ```
    /// use serde::{Serialize, Deserialize};
    /// use derive_deftly::Deftly;
    /// use tor_config::derive_deftly_template_Flattenable;
    ///
    /// #[derive(Serialize, Deserialize, Debug, Deftly)]
    /// #[derive_deftly(Flattenable)]
    /// struct A {
    ///     a: i32,
    /// }
    /// ```
    //
    // Note re semver:
    //
    // We re-export derive-deftly's template engine,
    // in the manner discussed by the derive-deftly docs.
    // See
    //  https://docs.rs/derive-deftly/latest/derive_deftly/macro.define_derive_deftly.html#exporting-a-template-for-use-by-other-crates
    //
    // The semantic behaviour of the template *does* have semver implications.
    export Flattenable for struct, expect items:

    impl<$tgens> $crate::Flattenable for $ttype
    where $twheres {
        fn has_field(s: &str) -> bool {
            // Ask serde (via the stunt `FieldExtractor` data format) which field names
            // the type's `Deserialize` impl declares, then scan that list for `s`.
            let fnames = $crate::flattenable_extract_fields::<'_, Self>();
            IntoIterator::into_iter(fnames).any(|f| *f == s)

        }
    }

    // Detect if flattenable_extract_fields panics
    //
    // The generated test is named `flattenable_test_` followed by
    // the snake-case form of the type name.
    #[test]
    fn $<flattenable_test_ ${snake_case $tname}>() {
        // Using $ttype::has_field avoids writing out again
        // the call to flattenable_extract_fields, with all its generics,
        // and thereby ensures that we didn't have a mismatch that
        // allows broken impls to slip through.
        // (We know the type is at least similar because we go via the Flattenable impl.)
        //
        // The looked-up name and the returned bool are irrelevant:
        // only whether the call panics matters.
        let _: bool = <$ttype as $crate::Flattenable>::has_field("");
    }
}
130pub use derive_deftly_template_Flattenable;
131
/// Helper for flattening deserialisation, compatible with [`serde_ignored`]
///
/// A combination of two structs `T` and `U`.
/// Field `.0` holds the `T` and field `.1` holds the `U`; both are public.
///
/// The serde representation flattens both structs into a single, larger, struct.
///
/// Furthermore, unlike plain use of `#[serde(flatten)]`,
/// `serde_ignored` will still detect fields which appear in serde input
/// but which form part of neither `T` nor `U`.
///
/// `T` and `U` must both be [`Flattenable`].
/// Usually that trait should be derived with
/// the [`Flattenable macro`](derive_deftly_template_Flattenable).
///
/// If it's desired to combine more than two structs, `Flatten` can be nested.
///
/// # Limitations
///
/// Field name overlaps are not detected.
/// Fields which appear in both structs
/// will be processed as part of `T` during deserialization.
/// They will be internally presented as duplicate fields during serialization,
/// with the outcome depending on the data format implementation.
///
/// # Example
///
/// ```
/// use serde::{Serialize, Deserialize};
/// use derive_deftly::Deftly;
/// use tor_config::{Flatten, derive_deftly_template_Flattenable};
///
/// #[derive(Serialize, Deserialize, Debug, Deftly, Eq, PartialEq)]
/// #[derive_deftly(Flattenable)]
/// struct A {
///     a: i32,
/// }
///
/// #[derive(Serialize, Deserialize, Debug, Deftly, Eq, PartialEq)]
/// #[derive_deftly(Flattenable)]
/// struct B {
///     b: String,
/// }
///
/// let combined: Flatten<A,B> = toml::from_str(r#"
///     a = 42
///     b = "hello"
/// "#).unwrap();
///
/// assert_eq!(
///    combined,
///    Flatten(A { a: 42 }, B { b: "hello".into() }),
/// );
/// ```
//
// We derive Deftly on Flatten itself so we can use
// derive_deftly_adhoc! to iterate over Flatten's two fields.
// This avoids us accidentally (for example) checking T's fields for passing to U.
//
// `Serialize`, `Deserialize` and `Flattenable` are not derived here:
// they are implemented by hand elsewhere in this module.
#[derive(Deftly, Debug, Clone, Copy, Hash, Ord, PartialOrd, Eq, PartialEq, Default)]
#[derive_deftly_adhoc]
#[allow(clippy::exhaustive_structs)]
pub struct Flatten<T, U>(pub T, pub U);
193
/// Types that can be used with [`Flatten`]
///
/// Usually, derived with
/// the [`Flattenable derive-deftly macro`](derive_deftly_template_Flattenable).
pub trait Flattenable {
    /// Does this type have a field named `f` ?
    ///
    /// This is an associated function (no `self`): the answer depends only on the type.
    fn has_field(f: &str) -> bool;
}
202
203//========== local helper macros ==========
204
/// Implement `deserialize_$what` as a call to `deserialize_any`.
///
/// `$args`, if provided, are any other formal arguments, not including the `Visitor`
///
/// `$args` is pasted directly after `self` in the generated signature,
/// so when it is non-empty it must start with a comma
/// (as in `call_any!(tuple, _: usize)`).
///
/// The generated method mentions `'de` and `Self::Error`,
/// so the macro is meant to be invoked inside an `impl<'de> Deserializer<'de>` block.
macro_rules! call_any { { $what:ident $( $args:tt )* } => { paste!{
    fn [<deserialize_ $what>]<V>(self $( $args )*, visitor: V) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        self.deserialize_any(visitor)
    }
} } }
216
/// Implement most `deserialize_*` as calls to `deserialize_any`.
///
/// The exceptions are the ones we need to handle specially in any of our types,
/// namely `any` itself and `struct`.
///
/// Each impl that invokes this macro must therefore provide
/// `deserialize_any` and `deserialize_struct` itself
/// (the latter possibly via `call_any!`).
macro_rules! call_any_for_rest { {} => {
    // Methods whose only argument besides `self` is the visitor
    call_any!(map);
    call_any!(bool);
    call_any!(byte_buf);
    call_any!(bytes);
    call_any!(char);
    call_any!(f32);
    call_any!(f64);
    call_any!(i128);
    call_any!(i16);
    call_any!(i32);
    call_any!(i64);
    call_any!(i8);
    call_any!(identifier);
    call_any!(ignored_any);
    call_any!(option);
    call_any!(seq);
    call_any!(str);
    call_any!(string);
    call_any!(u128);
    call_any!(u16);
    call_any!(u32);
    call_any!(u64);
    call_any!(u8);
    call_any!(unit);

    // Methods taking extra arguments, all of which are ignored
    call_any!(enum, _: &'static str, _: FieldList);
    call_any!(newtype_struct, _: &'static str );
    call_any!(tuple, _: usize );
    call_any!(tuple_struct, _: &'static str, _: usize );
    call_any!(unit_struct, _: &'static str );
} }
253
254//========== Implementations of Serialize and Flattenable ==========
255
derive_deftly_adhoc! {
    Flatten expect items:

    // Serialisation: delegate to a local named-field struct of references,
    // on which `#[serde(flatten)]` can be used.
    impl<T, U> Serialize for Flatten<T, U>
    where $( $ftype: Serialize, )
    {
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where S: Serializer
        {
            /// Version of outer `Flatten` containing references
            ///
            /// We give it the same name because the name is visible via serde
            ///
            /// The problems with `#[serde(flatten)]` don't apply to serialisation,
            /// because we're not trying to track ignored fields.
            /// But we can't just apply `#[serde(flatten)]` to `Flatten`
            /// since it doesn't work with tuple structs.
            #[derive(Serialize)]
            struct Flatten<'r, T, U> {
              // One flattened reference field per field of the outer tuple struct
              $(
                #[serde(flatten)]
                $fpatname: &'r $ftype,
              )
            }

            // Borrow each outer field into the local struct, then serialise that
            Flatten {
              $(
                $fpatname: &self.$fname,
              )
            }
            .serialize(serializer)
        }
    }

    /// `Flatten` may be nested
    impl<T, U> Flattenable for Flatten<T, U>
    where $( $ftype: Flattenable, )
    {
        fn has_field(f: &str) -> bool {
            // Expands to a chain of `<field type>::has_field(f) ||`, one link per field,
            // terminated by `false`: the field exists here if any child has it.
            $(
                $ftype::has_field(f)
                    ||
              )
                false
        }
    }
}
303
304//========== Deserialize implementation ==========
305
/// The keys and values we are to direct to a particular child
///
/// See the module-level comment for the algorithm.
///
/// Entries are appended at the back while the input map is being visited,
/// and consumed from the front by the [`MapAccess`] impl
/// when the child is deserialized.
#[derive(Default)]
struct Portion(VecDeque<(String, Value)>);
311
/// [`de::Visitor`] for `Flatten`
///
/// Carries no data; the `PhantomData` only records which `T` and `U` to produce.
struct FlattenVisitor<T, U>(PhantomData<(T, U)>);
314
/// Wrapper for a field name, impls [`de::Deserializer`]
///
/// Whatever the visitor asks for, it is handed the name as an owned string.
struct Key(String);
317
/// Type alias for reified error
///
/// [`serde_value::DeserializerError`] has one variant
/// for each of the constructors of [`de::Error`].
///
/// This is the error type of our internal `Deserializer` and `MapAccess` impls
/// (for `Portion` and `Key`), and of `Flatten::assemble`.
type FlattenError = serde_value::DeserializerError;
323
324//----- part 1: disassembly -----
325
derive_deftly_adhoc! {
    Flatten expect items:

    // where constraint on the Deserialize impl
    ${define FLATTENABLE $( $ftype: Deserialize<'de> + Flattenable, )}

    impl<'de, T, U> Deserialize<'de> for Flatten<T, U>
    where $FLATTENABLE
    {
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where D: Deserializer<'de>
        {
            // We ask the data format for a map; all the real work is in the visitor.
            deserializer.deserialize_map(FlattenVisitor(PhantomData))
        }
    }

    impl<'de, T, U> Visitor<'de> for FlattenVisitor<T,U>
    where $FLATTENABLE
    {
        type Value = Flatten<T, U>;

        fn expecting(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            write!(f, "map (for struct)")
        }

        fn visit_map<A>(self, mut mapa: A) -> Result<Self::Value, A::Error>
        where A: MapAccess<'de>
        {
            // See the module-level comment for an explanation.

            // $P is a local variable named after T/U: `p_t` or `p_u`, as appropriate
            ${define P $<p_ $fname>}

            // One empty Portion per child
            ${for fields { let mut $P = Portion::default(); }}

            // The repetition below expands to an `if .. else if .. else` chain,
            // testing the children in field order, so the first child claiming a key wins.
            // Keys claimed by no child have their value consumed as `IgnoredAny`,
            // which is what lets a `serde_ignored` proxy notice them.
            #[allow(clippy::suspicious_else_formatting)] // this is the least bad layout
            while let Some(k) = mapa.next_key::<String>()? {
              $(
                 if $ftype::has_field(&k) {
                    let v: Value = mapa.next_value()?;
                    $P.0.push_back((k, v));
                    continue;
                }
                else
              )
                {
                     let _: IgnoredAny = mapa.next_value()?;
                }
            }

            // Hand each child its Portion; convert our internal error
            // into the data format's error type.
            Flatten::assemble( ${for fields { $P, }} )
                .map_err(A::Error::custom)
        }
    }
}
381
382//----- part 2: reassembly -----
383
derive_deftly_adhoc! {
    Flatten expect items:

    impl<'de, T, U> Flatten<T, U>
    where $( $ftype: Deserialize<'de>, )
    {
        /// Assemble a `Flatten` out of the partition of its keys and values
        ///
        /// Uses `Portion`'s `Deserializer` impl and T and U's `Deserialize`
        ///
        /// Takes one `Portion` argument per field of `Flatten`, in field order.
        /// Returns the first error produced by a child's `Deserialize` impl, if any.
        fn assemble(
          $(
            $fpatname: Portion,
          )
        ) -> Result<Self, FlattenError> {
            // Each child is deserialized from its own Portion, acting as the data format
            Ok(Flatten(
              $(
                $ftype::deserialize($fpatname)?,
              )
            ))
        }
    }
}
406
407impl<'de> Deserializer<'de> for Portion {
408    type Error = FlattenError;
409
410    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
411    where
412        V: Visitor<'de>,
413    {
414        visitor.visit_map(self)
415    }
416
417    call_any!(struct, _: &'static str, _: FieldList);
418    call_any_for_rest!();
419}
420
421impl<'de> MapAccess<'de> for Portion {
422    type Error = FlattenError;
423
424    fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error>
425    where
426        K: DeserializeSeed<'de>,
427    {
428        let Some(entry) = self.0.get_mut(0) else {
429            return Ok(None);
430        };
431        let k = mem::take(&mut entry.0);
432        let k: K::Value = seed.deserialize(Key(k))?;
433        Ok(Some(k))
434    }
435
436    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
437    where
438        V: DeserializeSeed<'de>,
439    {
440        let v = self
441            .0
442            .pop_front()
443            .expect("next_value called inappropriately")
444            .1;
445        let r = seed.deserialize(v)?;
446        Ok(r)
447    }
448}
449
450impl<'de> Deserializer<'de> for Key {
451    type Error = FlattenError;
452
453    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
454    where
455        V: Visitor<'de>,
456    {
457        visitor.visit_string(self.0)
458    }
459
460    call_any!(struct, _: &'static str, _: FieldList);
461    call_any_for_rest!();
462}
463
464//========== Field extractor ==========
465
/// List of fields, appears in several APIs here
///
/// This is the type of the `fields` argument our `deserialize_struct` impl receives,
/// and of the list returned by `flattenable_extract_fields`.
type FieldList = &'static [&'static str];
468
/// Stunt "data format" which we use for extracting fields for derived `Flattenable` impls
///
/// The field extraction works as follows:
///  * We ask serde to deserialize `$ttype` from a `FieldExtractor`
///  * We expect the serde-macro-generated `Deserialize` impl to call `deserialize_struct`
///  * We return the list of fields to match up as an error
///
/// It holds no input data: no value is ever actually deserialized from it.
struct FieldExtractor;
476
/// Error resulting from successful operation of a [`FieldExtractor`]
///
/// Existence of this error is a *success*.
/// Unexpected behaviour by the type's serde implementation causes panics, not errors.
///
/// The payload is the field list which was passed to `deserialize_struct`.
#[derive(Error, Debug)]
#[error("Flattenable macro test gave error, so test passed successfully")]
struct FieldExtractorSuccess(FieldList);
484
485/// Extract fields of a struct, as viewed by `serde`
486///
487/// # Performance
488///
489/// In release builds, is very fast - all the serde nonsense boils off.
490/// In debug builds, maybe a hundred instructions, so not ideal,
491/// but it is at least O(1) since it doesn't have any loops.
492///
493/// # STABILITY WARNING
494///
495/// This function is `pub` but it is `#[doc(hidden)]`.
496/// The only legitimate use is via the `Flattenable` macro.
497/// There are **NO SEMVER GUARANTEES**
498///
499/// # Panics
500///
501/// Will panic on types whose serde field list cannot be simply extracted via serde,
502/// which will include things that aren't named fields structs,
503/// might include types decorated with unusual serde annotations.
504pub fn flattenable_extract_fields<'de, T: Deserialize<'de>>() -> FieldList {
505    let notional_input = FieldExtractor;
506    let FieldExtractorSuccess(fields) = T::deserialize(notional_input)
507        .map(|_| ())
508        .expect_err("unexpected success deserializing from FieldExtractor!");
509    fields
510}
511
512impl de::Error for FieldExtractorSuccess {
513    fn custom<E>(e: E) -> Self
514    where
515        E: Display,
516    {
517        panic!("Flattenable macro test failed - some *other* serde error: {e}");
518    }
519}
520
521impl<'de> Deserializer<'de> for FieldExtractor {
522    type Error = FieldExtractorSuccess;
523
524    fn deserialize_struct<V>(
525        self,
526        _name: &'static str,
527        fields: FieldList,
528        _visitor: V,
529    ) -> Result<V::Value, Self::Error>
530    where
531        V: Visitor<'de>,
532    {
533        Err(FieldExtractorSuccess(fields))
534    }
535
536    fn deserialize_any<V>(self, _: V) -> Result<V::Value, Self::Error>
537    where
538        V: Visitor<'de>,
539    {
540        panic!("test failed: Flattennable misimplemented by macros!");
541    }
542
543    call_any_for_rest!();
544}
545
546//========== tests ==========
547
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
    use super::*;

    use std::collections::HashMap;

    // Three flattenable structs with disjoint field names: {a, m}, {b, v} and {c}.

    #[derive(Serialize, Deserialize, Debug, Deftly, Eq, PartialEq)]
    #[derive_deftly(Flattenable)]
    struct A {
        a: i32,
        m: HashMap<String, String>,
    }

    #[derive(Serialize, Deserialize, Debug, Deftly, Eq, PartialEq)]
    #[derive_deftly(Flattenable)]
    struct B {
        b: i32,
        v: Vec<String>,
    }

    #[derive(Serialize, Deserialize, Debug, Deftly, Eq, PartialEq)]
    #[derive_deftly(Flattenable)]
    struct C {
        c: HashMap<String, String>,
    }

    /// Input containing every field of `A`, `B` and `C`,
    /// plus `spurious` which belongs to none of them.
    const TEST_INPUT: &str = r#"
        a = 42

        m.one = "unum"
        m.two = "bis"

        b = 99
        v = ["hi", "ho"]

        spurious = 66

        c.zed = "final"
    "#;

    /// `TEST_INPUT` parsed into a generic TOML value
    fn test_input() -> toml::Value {
        toml::from_str(TEST_INPUT).unwrap()
    }
    /// Deserialize `T` from the test input, without tracking ignored fields
    fn simply<'de, T: Deserialize<'de>>() -> T {
        test_input().try_into().unwrap()
    }
    /// Deserialize `T` from the test input via `serde_ignored`
    ///
    /// Returns the value, and the paths which `serde_ignored` reported as ignored.
    fn with_ignored<'de, T: Deserialize<'de>>() -> (T, Vec<String>) {
        let mut ignored = vec![];
        let f = serde_ignored::deserialize(
            test_input(), //
            |path| ignored.push(path.to_string()),
        )
        .unwrap();
        (f, ignored)
    }

    /// `Flatten<A, B>` gives the same values as deserializing `A` and `B` separately
    #[test]
    fn plain() {
        let f: Flatten<A, B> = test_input().try_into().unwrap();
        assert_eq!(f, Flatten(simply(), simply()));
    }

    /// Fields in neither `A` nor `B` (`c` and `spurious`) are reported as ignored
    #[test]
    fn ignored() {
        let (f, ignored) = with_ignored::<Flatten<A, B>>();
        assert_eq!(f, simply());
        assert_eq!(ignored, ["c", "spurious"]);
    }

    /// With `C` nested in, `c` is consumed and only `spurious` is reported
    #[test]
    fn nested() {
        let (f, ignored) = with_ignored::<Flatten<A, Flatten<B, C>>>();
        assert_eq!(f, simply());
        assert_eq!(ignored, ["spurious"]);
    }

    /// Serializing a nested `Flatten` produces a single flat map with all the fields
    #[test]
    fn ser() {
        let f: Flatten<A, Flatten<B, C>> = simply();

        assert_eq!(
            serde_json::to_value(f).unwrap(),
            serde_json::json!({
                "a": 42,
                "m": {
                    "one": "unum",
                    "two": "bis"
                },
                "b": 99,
                "v": [
                    "hi",
                    "ho"
                ],
                "c": {
                    "zed": "final"
                }
            }),
        );
    }

    /// This function exists only so we can disassemble it.
    ///
    /// To see what the result looks like in a release build:
    ///
    ///  * `RUSTFLAGS=-g cargo test -p tor-config --all-features --locked --release -- --nocapture flattenable_extract_fields_a_test`
    ///  * Observe the binary that's run, eg `Running unittests src/lib.rs (target/release/deps/tor_config-d4c4f29c45a0a3f9)`
    ///  * Disassemble it `objdump -d target/release/deps/tor_config-d4c4f29c45a0a3f9`
    ///  * Search for this function: `less +/'28flattenable_extract_fields_a.*:'`
    ///
    /// At the time of writing, the result is three instructions:
    /// load the address of the list, load a register with the constant 2 (the length),
    /// return.
    fn flattenable_extract_fields_a() -> FieldList {
        flattenable_extract_fields::<'_, A>()
    }

    /// Calls `flattenable_extract_fields_a` through a `black_box`ed function pointer
    /// and prints the field list; there are no assertions here.
    #[test]
    fn flattenable_extract_fields_a_test() {
        use std::hint::black_box;
        let f: fn() -> _ = black_box(flattenable_extract_fields_a);
        eprintln!("{:?}", f());
    }
}