Skip to main content

tor_config/
flatten.rs

1//! Similar to `#[serde(flatten)]` but works with [`serde_ignored`]
2//!
3//! Our approach to deserialize a [`Flatten`] is as follows:
4//!
5//!  * We tell the input data format (underlying deserializer) that we want a map.
6//!  * In our visitor, we visit each key in the map in order
7//!  * For each key, we consult `Flattenable::has_field` to find out which child it's in
8//!    (fields in T shadow fields in U, as with serde),
9//!    and store the key and the value in the appropriate [`Portion`].
10//!    (We must store the value as a [`serde_value::Value`]
11//!    since we don't know what type it should be,
12//!    and can't know until we are ready to enter T and U's [`Deserialize`] impls.)
13//!  * If it's in neither T nor U, we explicitly ignore the value
14//!  * When we've processed all the fields, we call the actual deserialisers for T and U:
15//!    we take on the role of the data format, giving each of T and U a map.
16//!
//! From the point of view of T and U, we offer each of them a subset of the fields,
18//! having already rendered the keys to strings and the values to `Value`.
19//!
20//! From the point of view of the data format (which might be a `serde_ignored` proxy)
21//! we consume the union of the fields, and ignore the rest.
22//!
23//! ### Rationale and alternatives
24//!
25//! The key difficulty is this:
26//! we want to call [`Deserializer::deserialize_ignored_any`]
27//! on our input data format for precisely the fields which neither T nor U want.
28//! We must achieve this somehow using information from T or U.
29//! If we tried to use only the [`Deserialize`] impls,
//! the only way to detect this would be to call their `deserialize` methods
31//! and watch to see if they in turn call `deserialize_ignored_any`.
32//! But we need to be asking each of T and U this question for each field:
33//! the shape of [`MapAccess`] puts the data structure in charge of sequencing.
34//! So we would need to somehow suspend `T`'s deserialisation,
//! and call `U`'s, and then suspend `U`'s, and go back to `T`.
36//!
37//! Other possibilities that seemed worse:
38//!
39//!  * Use threads.
40//!    We could spawn a thread for each of `T` and `U`,
41//!    allowing us to run them in parallel and control their execution flow.
42//!
43//!  * Use coroutines eg. [corosensei](https://lib.rs/crates/corosensei)
44//!    (by Amanieu, author of hashbrown etc.)
45//!
46//!  * Instead of suspending and restarting `T` and `U`'s deserialisation,
47//!    discard the partially-deserialised `T` and `U` and restart them each time
48//!    (with cloned copies of the `Value`s).  This is O(n^2) and involves much boxing.
49//!
50//! # References
51//!
52//!  * Tickets against `serde-ignored`:
53//!    <https://github.com/dtolnay/serde-ignored/issues/17>
54//!    <https://github.com/dtolnay/serde-ignored/issues/10>
55//!
56//!  * Workaround with `HashMap` that doesn't quite work right:
57//!    <https://github.com/dtolnay/serde-ignored/issues/10#issuecomment-1044058310>
58//!    <https://github.com/serde-rs/serde/issues/2176>
59//!
60//!  * Discussion in Tor Project gitlab re Arti configuration:
61//!    <https://gitlab.torproject.org/tpo/core/arti/-/merge_requests/1599#note_2944510>
62
63use std::collections::VecDeque;
64use std::fmt::{self, Display};
65use std::marker::PhantomData;
66use std::mem;
67
68use derive_deftly::{Deftly, define_derive_deftly, derive_deftly_adhoc};
69use paste::paste;
70use serde::de::{self, DeserializeSeed, Deserializer, Error as _, IgnoredAny, MapAccess, Visitor};
71use serde::{Deserialize, Serialize, Serializer};
72use serde_value::Value;
73use thiserror::Error;
74
75// Must come first so we can refer to it in docs
76define_derive_deftly! {
77    /// Derives [`Flattenable`] for a struct
78    ///
79    /// # Limitations
80    ///
81    /// Some serde attributes might not be supported.
82    /// For example, ones which make the type no longer deserialize as a named fields struct.
83    /// This will be detected by a macro-generated always-failing test case.
84    ///
85    /// Most serde attributes (eg field renaming and ignoring) will be fine.
86    ///
87    /// # Example
88    ///
89    /// ```
90    /// use serde::{Serialize, Deserialize};
91    /// use derive_deftly::Deftly;
92    /// use tor_config::derive_deftly_template_Flattenable;
93    ///
94    /// #[derive(Serialize, Deserialize, Debug, Deftly)]
95    /// #[derive_deftly(Flattenable)]
96    /// struct A {
97    ///     a: i32,
98    /// }
99    /// ```
100    //
101    // Note re semver:
102    //
103    // We re-export derive-deftly's template engine, in the manner discussed by the d-a docs.
104    // See
105    //  https://docs.rs/derive-deftly/latest/derive_deftly/macro.define_derive_deftly.html#exporting-a-template-for-use-by-other-crates
106    //
107    // The semantic behaviour of the template *does* have semver implications.
108    export Flattenable for struct, expect items, beta_deftly:
109
110    ${impl $crate::Flattenable} {
111        fn has_field(s: &str) -> bool {
112            let fnames = $crate::flattenable_extract_fields::<'_, Self>();
113            IntoIterator::into_iter(fnames).any(|f| *f == s)
114
115        }
116    }
117
118    // Detect if flattenable_extract_fields panics
119    #[test]
120    fn $<flattenable_test_ ${snake_case $tname}>() {
121        // Using $ttype::has_field avoids writing out again
122        // the call to flattenable_extract_fields, with all its generics,
123        // and thereby ensures that we didn't have a mismatch that
124        // allows broken impls to slip through.
125        // (We know the type is at least similar because we go via the Flattenable impl.)
126        let _: bool = <$ttype as $crate::Flattenable>::has_field("");
127    }
128}
129pub use derive_deftly_template_Flattenable;
130
131/// Helper for flattening deserialisation, compatible with [`serde_ignored`]
132///
133/// A combination of two structs `T` and `U`.
134///
135/// The serde representation flattens both structs into a single, larger, struct.
136///
137/// Furthermore, unlike plain use of `#[serde(flatten)]`,
138/// `serde_ignored` will still detect fields which appear in serde input
139/// but which form part of neither `T` nor `U`.
140///
141/// `T` and `U` must both be [`Flattenable`].
142/// Usually that trait should be derived with
143/// the [`Flattenable macro`](derive_deftly_template_Flattenable).
144///
145/// If it's desired to combine more than two structs, `Flatten` can be nested.
146///
147/// # Limitations
148///
149/// Field name overlaps are not detected.
150/// Fields which appear in both structs
151/// will be processed as part of `T` during deserialization.
152/// They will be internally presented as duplicate fields during serialization,
153/// with the outcome depending on the data format implementation.
154///
155/// # Example
156///
157/// ```
158/// use serde::{Serialize, Deserialize};
159/// use derive_deftly::Deftly;
160/// use tor_config::{Flatten, derive_deftly_template_Flattenable};
161///
162/// #[derive(Serialize, Deserialize, Debug, Deftly, Eq, PartialEq)]
163/// #[derive_deftly(Flattenable)]
164/// struct A {
165///     a: i32,
166/// }
167///
168/// #[derive(Serialize, Deserialize, Debug, Deftly, Eq, PartialEq)]
169/// #[derive_deftly(Flattenable)]
170/// struct B {
171///     b: String,
172/// }
173///
174/// let combined: Flatten<A,B> = toml::from_str(r#"
175///     a = 42
176///     b = "hello"
177/// "#).unwrap();
178///
179/// assert_eq!(
180///    combined,
181///    Flatten(A { a: 42 }, B { b: "hello".into() }),
182/// );
183/// ```
184//
185// We derive Deftly on Flatten itself so we can use
186// derive_deftly_adhoc! to iterate over Flatten's two fields.
187// This avoids us accidentally (for example) checking T's fields for passing to U.
188#[derive(Deftly, Debug, Clone, Copy, Hash, Ord, PartialOrd, Eq, PartialEq, Default)]
189#[derive_deftly_adhoc]
190#[allow(clippy::exhaustive_structs)]
191pub struct Flatten<T, U>(pub T, pub U);
192
193/// Types that can be used with [`Flatten`]
194///
195/// Usually, derived with
196/// the [`Flattenable derive-deftly macro`](derive_deftly_template_Flattenable).
197pub trait Flattenable {
198    /// Does this type have a field named `s` ?
199    fn has_field(f: &str) -> bool;
200}
201
202//========== local helper macros ==========
203
204/// Implement `deserialize_$what` as a call to `deserialize_any`.
205///
206/// `$args`, if provided, are any other formal arguments, not including the `Visitor`
207macro_rules! call_any { { $what:ident $( $args:tt )* } => { paste!{
208    fn [<deserialize_ $what>]<V>(self $( $args )*, visitor: V) -> Result<V::Value, Self::Error>
209    where
210        V: Visitor<'de>,
211    {
212        self.deserialize_any(visitor)
213    }
214} } }
215
216/// Implement most `deserialize_*` as calls to `deserialize_any`.
217///
218/// The exceptions are the ones we need to handle specially in any of our types,
219/// namely `any` itself and `struct`.
220macro_rules! call_any_for_rest { {} => {
221    call_any!(map);
222    call_any!(bool);
223    call_any!(byte_buf);
224    call_any!(bytes);
225    call_any!(char);
226    call_any!(f32);
227    call_any!(f64);
228    call_any!(i128);
229    call_any!(i16);
230    call_any!(i32);
231    call_any!(i64);
232    call_any!(i8);
233    call_any!(identifier);
234    call_any!(ignored_any);
235    call_any!(option);
236    call_any!(seq);
237    call_any!(str);
238    call_any!(string);
239    call_any!(u128);
240    call_any!(u16);
241    call_any!(u32);
242    call_any!(u64);
243    call_any!(u8);
244    call_any!(unit);
245
246    call_any!(enum, _: &'static str, _: FieldList);
247    call_any!(newtype_struct, _: &'static str );
248    call_any!(tuple, _: usize );
249    call_any!(tuple_struct, _: &'static str, _: usize );
250    call_any!(unit_struct, _: &'static str );
251} }
252
253//========== Implementations of Serialize and Flattenable ==========
254
255derive_deftly_adhoc! {
256    Flatten expect items:
257
258    impl<T, U> Serialize for Flatten<T, U>
259    where $( $ftype: Serialize, )
260    {
261        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
262        where S: Serializer
263        {
264            /// version of outer `Flatten` containing references
265            ///
266            /// We give it the same name because the name is visible via serde
267            ///
268            /// The problems with `#[serde(flatten)]` don't apply to serialisation,
269            /// because we're not trying to track ignored fields.
270            /// But we can't just apply `#[serde(flatten)]` to `Flatten`
271            /// since it doesn't work with tuple structs.
272            #[derive(Serialize)]
273            struct Flatten<'r, T, U> {
274              $(
275                #[serde(flatten)]
276                $fpatname: &'r $ftype,
277              )
278            }
279
280            Flatten {
281              $(
282                $fpatname: &self.$fname,
283              )
284            }
285            .serialize(serializer)
286        }
287    }
288
289    /// `Flatten` may be nested
290    impl<T, U> Flattenable for Flatten<T, U>
291    where $( $ftype: Flattenable, )
292    {
293        fn has_field(f: &str) -> bool {
294            $(
295                $ftype::has_field(f)
296                    ||
297              )
298                false
299        }
300    }
301}
302
303//========== Deserialize implementation ==========
304
305/// The keys and values we are to direct to a particular child
306///
307/// See the module-level comment for the algorithm.
308#[derive(Default)]
309struct Portion(VecDeque<(String, Value)>);
310
311/// [`de::Visitor`] for `Flatten`
312struct FlattenVisitor<T, U>(PhantomData<(T, U)>);
313
314/// Wrapper for a field name, impls [`de::Deserializer`]
315struct Key(String);
316
317/// Type alias for reified error
318///
319/// [`serde_value::DeserializerError`] has one variant
320/// for each of the constructors of [`de::Error`].
321type FlattenError = serde_value::DeserializerError;
322
323//----- part 1: disassembly -----
324
325derive_deftly_adhoc! {
326    Flatten expect items:
327
328    // where constraint on the Deserialize impl
329    ${define FLATTENABLE $( $ftype: Deserialize<'de> + Flattenable, )}
330
331    impl<'de, T, U> Deserialize<'de> for Flatten<T, U>
332    where $FLATTENABLE
333    {
334        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
335        where D: Deserializer<'de>
336        {
337            deserializer.deserialize_map(FlattenVisitor(PhantomData))
338        }
339    }
340
341    impl<'de, T, U> Visitor<'de> for FlattenVisitor<T,U>
342    where $FLATTENABLE
343    {
344        type Value = Flatten<T, U>;
345
346        fn expecting(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
347            write!(f, "map (for struct)")
348        }
349
350        fn visit_map<A>(self, mut mapa: A) -> Result<Self::Value, A::Error>
351        where A: MapAccess<'de>
352        {
353            // See the module-level comment for an explanation.
354
355            // $P is a local variable named after T/U: `p_t` or `p_u`, as appropriate
356            ${define P $<p_ $fname>}
357
358            ${for fields { let mut $P = Portion::default(); }}
359
360            #[allow(clippy::suspicious_else_formatting)] // this is the least bad layout
361            while let Some(k) = mapa.next_key::<String>()? {
362              $(
363                 if $ftype::has_field(&k) {
364                    let v: Value = mapa.next_value()?;
365                    $P.0.push_back((k, v));
366                    continue;
367                }
368                else
369              )
370                {
371                     let _: IgnoredAny = mapa.next_value()?;
372                }
373            }
374
375            Flatten::assemble( ${for fields { $P, }} )
376                .map_err(A::Error::custom)
377        }
378    }
379}
380
381//----- part 2: reassembly -----
382
383derive_deftly_adhoc! {
384    Flatten expect items:
385
386    impl<'de, T, U> Flatten<T, U>
387    where $( $ftype: Deserialize<'de>, )
388    {
389        /// Assemble a `Flatten` out of the partition of its keys and values
390        ///
391        /// Uses `Portion`'s `Deserializer` impl and T and U's `Deserialize`
392        fn assemble(
393          $(
394            $fpatname: Portion,
395          )
396        ) -> Result<Self, FlattenError> {
397            Ok(Flatten(
398              $(
399                $ftype::deserialize($fpatname)?,
400              )
401            ))
402        }
403    }
404}
405
406impl<'de> Deserializer<'de> for Portion {
407    type Error = FlattenError;
408
409    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
410    where
411        V: Visitor<'de>,
412    {
413        visitor.visit_map(self)
414    }
415
416    call_any!(struct, _: &'static str, _: FieldList);
417    call_any_for_rest!();
418}
419
420impl<'de> MapAccess<'de> for Portion {
421    type Error = FlattenError;
422
423    fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error>
424    where
425        K: DeserializeSeed<'de>,
426    {
427        let Some(entry) = self.0.get_mut(0) else {
428            return Ok(None);
429        };
430        let k = mem::take(&mut entry.0);
431        let k: K::Value = seed.deserialize(Key(k))?;
432        Ok(Some(k))
433    }
434
435    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
436    where
437        V: DeserializeSeed<'de>,
438    {
439        let v = self
440            .0
441            .pop_front()
442            .expect("next_value called inappropriately")
443            .1;
444        let r = seed.deserialize(v)?;
445        Ok(r)
446    }
447}
448
449impl<'de> Deserializer<'de> for Key {
450    type Error = FlattenError;
451
452    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
453    where
454        V: Visitor<'de>,
455    {
456        visitor.visit_string(self.0)
457    }
458
459    call_any!(struct, _: &'static str, _: FieldList);
460    call_any_for_rest!();
461}
462
463//========== Field extractor ==========
464
465/// List of fields, appears in several APIs here
466type FieldList = &'static [&'static str];
467
468/// Stunt "data format" which we use for extracting fields for derived `Flattenable` impls
469///
470/// The field extraction works as follows:
471///  * We ask serde to deserialize `$ttype` from a `FieldExtractor`
472///  * We expect the serde-macro-generated `Deserialize` impl to call `deserialize_struct`
473///  * We return the list of fields to match up as an error
474struct FieldExtractor;
475
476/// Error resulting from successful operation of a [`FieldExtractor`]
477///
478/// Existence of this error is a *success*.
479/// Unexpected behaviour by the type's serde implementation causes panics, not errors.
480#[derive(Error, Debug)]
481#[error("Flattenable macro test gave error, so test passed successfully")]
482struct FieldExtractorSuccess(FieldList);
483
484/// Extract fields of a struct, as viewed by `serde`
485///
486/// # Performance
487///
488/// In release builds, is very fast - all the serde nonsense boils off.
489/// In debug builds, maybe a hundred instructions, so not ideal,
490/// but it is at least O(1) since it doesn't have any loops.
491///
492/// # STABILITY WARNING
493///
494/// This function is `pub` but it is `#[doc(hidden)]`.
495/// The only legitimate use is via the `Flattenable` macro.
496/// There are **NO SEMVER GUARANTEES**
497///
498/// # Panics
499///
500/// Will panic on types whose serde field list cannot be simply extracted via serde,
501/// which will include things that aren't named fields structs,
502/// might include types decorated with unusual serde annotations.
503pub fn flattenable_extract_fields<'de, T: Deserialize<'de>>() -> FieldList {
504    let notional_input = FieldExtractor;
505    let FieldExtractorSuccess(fields) = T::deserialize(notional_input)
506        .map(|_| ())
507        .expect_err("unexpected success deserializing from FieldExtractor!");
508    fields
509}
510
511impl de::Error for FieldExtractorSuccess {
512    fn custom<E>(e: E) -> Self
513    where
514        E: Display,
515    {
516        panic!("Flattenable macro test failed - some *other* serde error: {e}");
517    }
518}
519
520impl<'de> Deserializer<'de> for FieldExtractor {
521    type Error = FieldExtractorSuccess;
522
523    fn deserialize_struct<V>(
524        self,
525        _name: &'static str,
526        fields: FieldList,
527        _visitor: V,
528    ) -> Result<V::Value, Self::Error>
529    where
530        V: Visitor<'de>,
531    {
532        Err(FieldExtractorSuccess(fields))
533    }
534
535    fn deserialize_any<V>(self, _: V) -> Result<V::Value, Self::Error>
536    where
537        V: Visitor<'de>,
538    {
539        panic!("test failed: Flattennable misimplemented by macros!");
540    }
541
542    call_any_for_rest!();
543}
544
545//========== tests ==========
546
547#[cfg(test)]
548mod test {
549    // @@ begin test lint list maintained by maint/add_warning @@
550    #![allow(clippy::bool_assert_comparison)]
551    #![allow(clippy::clone_on_copy)]
552    #![allow(clippy::dbg_macro)]
553    #![allow(clippy::mixed_attributes_style)]
554    #![allow(clippy::print_stderr)]
555    #![allow(clippy::print_stdout)]
556    #![allow(clippy::single_char_pattern)]
557    #![allow(clippy::unwrap_used)]
558    #![allow(clippy::unchecked_time_subtraction)]
559    #![allow(clippy::useless_vec)]
560    #![allow(clippy::needless_pass_by_value)]
561    #![allow(clippy::string_slice)] // See arti#2571
562    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
563    use super::*;
564
565    use std::collections::HashMap;
566
567    #[derive(Serialize, Deserialize, Debug, Deftly, Eq, PartialEq)]
568    #[derive_deftly(Flattenable)]
569    struct A {
570        a: i32,
571        m: HashMap<String, String>,
572    }
573
574    #[derive(Serialize, Deserialize, Debug, Deftly, Eq, PartialEq)]
575    #[derive_deftly(Flattenable)]
576    struct B {
577        b: i32,
578        v: Vec<String>,
579    }
580
581    #[derive(Serialize, Deserialize, Debug, Deftly, Eq, PartialEq)]
582    #[derive_deftly(Flattenable)]
583    struct C {
584        c: HashMap<String, String>,
585    }
586
587    const TEST_INPUT: &str = r#"
588        a = 42
589
590        m.one = "unum"
591        m.two = "bis"
592
593        b = 99
594        v = ["hi", "ho"]
595
596        spurious = 66
597
598        c.zed = "final"
599    "#;
600
601    fn test_input() -> toml::Value {
602        toml::from_str(TEST_INPUT).unwrap()
603    }
604    fn simply<'de, T: Deserialize<'de>>() -> T {
605        test_input().try_into().unwrap()
606    }
607    fn with_ignored<'de, T: Deserialize<'de>>() -> (T, Vec<String>) {
608        let mut ignored = vec![];
609        let f = serde_ignored::deserialize(
610            test_input(), //
611            |path| ignored.push(path.to_string()),
612        )
613        .unwrap();
614        (f, ignored)
615    }
616
617    #[test]
618    fn plain() {
619        let f: Flatten<A, B> = test_input().try_into().unwrap();
620        assert_eq!(f, Flatten(simply(), simply()));
621    }
622
623    #[test]
624    fn ignored() {
625        let (f, ignored) = with_ignored::<Flatten<A, B>>();
626        assert_eq!(f, simply());
627        assert_eq!(ignored, ["c", "spurious"]);
628    }
629
630    #[test]
631    fn nested() {
632        let (f, ignored) = with_ignored::<Flatten<A, Flatten<B, C>>>();
633        assert_eq!(f, simply());
634        assert_eq!(ignored, ["spurious"]);
635    }
636
637    #[test]
638    fn ser() {
639        let f: Flatten<A, Flatten<B, C>> = simply();
640
641        assert_eq!(
642            serde_json::to_value(f).unwrap(),
643            serde_json::json!({
644                "a": 42,
645                "m": {
646                    "one": "unum",
647                    "two": "bis"
648                },
649                "b": 99,
650                "v": [
651                    "hi",
652                    "ho"
653                ],
654                "c": {
655                    "zed": "final"
656                }
657            }),
658        );
659    }
660
661    /// This function exists only so we can disassemble it.
662    ///
663    /// To see what the result looks like in a release build:
664    ///
665    ///  * `RUSTFLAGS=-g cargo test -p tor-config --all-features --locked --release -- --nocapture flattenable_extract_fields_a_test`
666    ///  * Observe the binary that's run, eg `Running unittests src/lib.rs (target/release/deps/tor_config-d4c4f29c45a0a3f9)`
667    ///  * Disassemble it `objdump -d target/release/deps/tor_config-d4c4f29c45a0a3f9`
668    ///  * Search for this function: `less +/'28flattenable_extract_fields_a.*:'`
669    ///
670    /// At the time of writing, the result is three instructions:
671    /// load the address of the list, load a register with the constant 2 (the length),
672    /// return.
673    fn flattenable_extract_fields_a() -> FieldList {
674        flattenable_extract_fields::<'_, A>()
675    }
676
677    #[test]
678    fn flattenable_extract_fields_a_test() {
679        use std::hint::black_box;
680        let f: fn() -> _ = black_box(flattenable_extract_fields_a);
681        eprintln!("{:?}", f());
682    }
683}