Skip to main content

sip_header/
macros.rs

/// Generates a non-exhaustive enum mapping Rust variants to canonical protocol strings.
///
/// Produces: enum definition, `ALL` const, `as_str()`, `Display`, `AsRef<str>`,
/// and `FromStr`. `FromStr` uses `eq_ignore_ascii_case` — appropriate for
/// user-facing catalog types (header names, variable names) where input may
/// come from config files. Wire protocol state types use hand-written strict
/// `FromStr` instead.
///
/// Two error forms:
///
/// - `error_type: ParseMyEnumError,` — the error newtype is defined separately
///   by the caller as `struct ParseMyEnumError(pub String)`.
/// - `error_type: ParseMyEnumError => "unknown my value",` — the newtype, its
///   `Display` (`"unknown my value: {input}"`), and `std::error::Error` are
///   generated. The message is emitted verbatim (it is not used as a format
///   string), so it may safely contain `{` or `}`.
///
/// An optional leading `tests_mod: my_enum_tests,` generates a `#[cfg(test)]`
/// module with round-trip, ASCII case-insensitivity, `Display`, and
/// unknown-input tests over `ALL`. Those tests require `Debug` and `PartialEq`
/// on the error type (the generated error form derives both).
///
/// Standard-library traits and types in the generated impls are named by
/// absolute `::std::` paths, so a caller-side `Result<T>` alias or a local
/// module named `std` does not break expansion. The `serde` derives are gated
/// on the *invoking* crate's `serde` feature and resolve `serde` from the
/// invoking crate's scope.
///
/// # Example
///
/// ```ignore
/// define_header_enum! {
///     tests_mod: my_enum_tests,
///     error_type: ParseMyEnumError => "unknown my value",
///     /// Doc comment for the enum.
///     pub enum MyEnum {
///         Foo => "foo-wire",
///         Bar => "bar-wire",
///     }
/// }
/// ```
#[macro_export]
macro_rules! define_header_enum {
    (
        $(tests_mod: $tests_mod:ident,)?
        error_type: $Err:ident $(=> $err_msg:literal)?,
        $(#[$enum_meta:meta])*
        $vis:vis enum $Name:ident {
            $(
                $(#[$var_meta:meta])*
                $variant:ident => $wire:literal
            ),+ $(,)?
        }
    ) => {
        $(
            #[doc = concat!("Error for an unrecognized value; displays as `", $err_msg, ": <input>`.")]
            #[derive(Debug, Clone, PartialEq, Eq)]
            $vis struct $Err(pub ::std::string::String);

            impl ::std::fmt::Display for $Err {
                fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
                    // Write the message verbatim instead of splicing it into a
                    // format string: braces in the message must not be parsed
                    // as format placeholders.
                    f.write_str(concat!($err_msg, ": "))?;
                    f.write_str(&self.0)
                }
            }

            impl ::std::error::Error for $Err {}
        )?

        $(#[$enum_meta])*
        #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
        #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
        #[non_exhaustive]
        #[allow(missing_docs)]
        $vis enum $Name {
            $(
                $(#[$var_meta])*
                $variant,
            )+
        }

        impl $Name {
            /// All variants, in declaration order (respecting any `#[cfg]` attributes).
            // allow(unused_doc_comments): variant doc attrs are propagated onto
            // array elements so that #[cfg] attrs also propagate; the doc attrs
            // are harmless noise here. Same pattern in as_str/from_str below.
            #[allow(unused_doc_comments)]
            pub const ALL: &'static [Self] = &[
                $(
                    $(#[$var_meta])*
                    Self::$variant,
                )+
            ];

            /// Canonical protocol string.
            #[allow(unused_doc_comments)]
            pub fn as_str(&self) -> &'static str {
                match self {
                    $(
                        $(#[$var_meta])*
                        Self::$variant => $wire,
                    )+
                }
            }
        }

        impl ::std::fmt::Display for $Name {
            fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
                f.write_str(self.as_str())
            }
        }

        impl ::std::convert::AsRef<str> for $Name {
            fn as_ref(&self) -> &str {
                self.as_str()
            }
        }

        impl ::std::str::FromStr for $Name {
            type Err = $Err;

            /// ASCII case-insensitive match against each variant's wire string.
            #[allow(unused_doc_comments)]
            fn from_str(s: &str) -> ::std::result::Result<Self, Self::Err> {
                $(
                    $(#[$var_meta])*
                    if s.eq_ignore_ascii_case($wire) {
                        return ::std::result::Result::Ok(Self::$variant);
                    }
                )+
                ::std::result::Result::Err($Err(::std::borrow::ToOwned::to_owned(s)))
            }
        }

        $(
            #[cfg(test)]
            mod $tests_mod {
                use super::{$Err, $Name};

                #[test]
                fn round_trip() {
                    for v in $Name::ALL {
                        assert_eq!(v.as_str().parse::<$Name>(), Ok(*v));
                    }
                }

                #[test]
                fn case_insensitive() {
                    // ASCII-only case mapping, matching `eq_ignore_ascii_case`
                    // in `from_str`. Unicode `to_uppercase` can change the text
                    // (e.g. `ß` -> `SS`) and would fail for non-ASCII wire strings.
                    for v in $Name::ALL {
                        assert_eq!(v.as_str().to_ascii_lowercase().parse::<$Name>(), Ok(*v));
                        assert_eq!(v.as_str().to_ascii_uppercase().parse::<$Name>(), Ok(*v));
                    }
                }

                #[test]
                fn display_matches_as_str() {
                    for v in $Name::ALL {
                        assert_eq!(v.to_string(), v.as_str());
                    }
                }

                #[test]
                fn unknown_input_err() {
                    let input = "\u{0}no-such-value\u{0}";
                    assert_eq!(input.parse::<$Name>(), Err($Err(input.to_string())));
                }
            }
        )?
    };
}
161
/// Implements `FromStr` by delegating to an inherent `parse(&str)` method.
///
/// `$Type` must provide `parse(&str) -> Result<$Type, $Err>`; the generated
/// `from_str` forwards to it unchanged. The return type is spelled
/// `::std::result::Result` so that a crate-local `Result<T>` alias in scope
/// at the call site does not break the generated signature. A trailing comma
/// after the error type is accepted.
macro_rules! impl_from_str_via_parse {
    ($Type:ty, $Err:ty $(,)?) => {
        impl ::std::str::FromStr for $Type {
            type Err = $Err;

            fn from_str(s: &str) -> ::std::result::Result<Self, Self::Err> {
                Self::parse(s)
            }
        }
    };
}
174
#[cfg(test)]
mod tests {
    // New-form invocation: the macro generates the error newtype and a
    // `test_enum_generated` module of per-variant tests.
    define_header_enum! {
        tests_mod: test_enum_generated,
        error_type: ParseTestEnumError => "unknown test value",
        /// Exercises generated error newtype, `ALL`, and test module.
        pub(crate) enum TestEnum {
            /// `Foo-Wire`.
            Foo => "Foo-Wire",
            /// `Bar-Wire`.
            Bar => "Bar-Wire",
            /// `Draft-Wire`.
            #[cfg(feature = "draft")]
            Draft => "Draft-Wire",
        }
    }

    /// Caller-supplied error type for the old-form (message-less) invocation.
    #[derive(Debug, Clone, PartialEq, Eq)]
    pub(crate) struct ParseOldEnumError(pub String);

    impl std::fmt::Display for ParseOldEnumError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "unknown old value: {}", self.0)
        }
    }

    impl std::error::Error for ParseOldEnumError {}

    // Old-form invocation: the error type above is reused, nothing is generated for it.
    define_header_enum! {
        error_type: ParseOldEnumError,
        /// Old-form invocation stays source-compatible.
        pub(crate) enum OldEnum {
            /// `Old-Wire`.
            One => "Old-Wire",
        }
    }

    #[test]
    fn generated_error_display() {
        let rendered = ParseTestEnumError(String::from("nope")).to_string();
        assert_eq!(rendered, "unknown test value: nope");
    }

    #[test]
    fn all_const_respects_cfg() {
        // The expected list carries the same `#[cfg]` as the `Draft` variant,
        // so it matches `ALL` with or without the feature.
        let expected: &[TestEnum] = &[
            TestEnum::Foo,
            TestEnum::Bar,
            #[cfg(feature = "draft")]
            TestEnum::Draft,
        ];
        assert_eq!(TestEnum::ALL, expected);
    }

    #[test]
    fn old_form_generates_all() {
        assert_eq!(OldEnum::ALL, &[OldEnum::One]);
        let parsed: Result<OldEnum, ParseOldEnumError> = "old-wire".parse();
        assert_eq!(parsed, Ok(OldEnum::One));
    }
}