tui_theme_builder_derive/
lib.rs

1use core::panic;
2use proc_macro::TokenStream;
3use proc_macro2::{Punct, Spacing, TokenStream as TokenStream2, TokenTree};
4use quote::quote;
5use syn::{parse::ParseStream, parse_macro_input, Attribute, Data, DeriveInput, Fields, Ident};
6
7/// # Panics
8/// - Panics if derive is not attached to a struct
9/// - Panics if no `context` attribute is found
10#[allow(clippy::too_many_lines)]
11#[proc_macro_derive(ThemeBuilder, attributes(context, builder, style, border_type))]
12pub fn derive_theme_builder(input: TokenStream) -> TokenStream {
13    let input = parse_macro_input!(input as DeriveInput);
14
15    let struct_name = &input.ident;
16
17    let Data::Struct(data) = &input.data else {
18        panic!("derive must be attached to a struct");
19    };
20
21    let builder_attr = extract_builder_attribute(&input.attrs);
22    let Some(builder_attr) = builder_attr else {
23        panic!("no `context` attribute found on struct");
24    };
25    let context_name = process_builder_struct_attribute(builder_attr);
26    let Some(context_name) = context_name else {
27        panic!("no `context` field found in builder annotation");
28    };
29
30    let Fields::Named(fields) = &data.fields else {
31        panic!("expected named fields, got {:?}", &data.fields)
32    };
33
34    let mut field_constructors: Vec<TokenStream2> = Vec::new();
35
36    for field in &fields.named {
37        let field_name = field.ident.as_ref().unwrap();
38        let field_type = &field.ty;
39
40        let mut field_constructor = quote! {};
41
42        // Handle `Style` tagged fields.
43        let attr = extract_style_attribute(&field.attrs);
44        if let Some(attr) = attr {
45            let style_values = process_style_attribute(attr);
46
47            field_constructor.extend(quote! {
48                #field_name: ratatui::style::Style::default()
49            });
50
51            if let Some(foreground_color) = style_values.foreground {
52                field_constructor.extend(quote! {
53                    .fg(context.#foreground_color.clone().into())
54                });
55            }
56
57            if let Some(background_color) = style_values.background {
58                field_constructor.extend(quote! {
59                    .bg(context.#background_color.clone().into())
60                });
61            }
62
63            if style_values.bold.is_some() {
64                field_constructor.extend(quote! {
65                    .add_modifier(ratatui::style::Modifier::BOLD)
66                });
67            }
68
69            if style_values.dim.is_some() {
70                field_constructor.extend(quote! {
71                    .add_modifier(ratatui::style::Modifier::DIM)
72                });
73            }
74
75            if style_values.italic.is_some() {
76                field_constructor.extend(quote! {
77                    .add_modifier(ratatui::style::Modifier::ITALIC)
78                });
79            }
80
81            if style_values.underlined.is_some() {
82                field_constructor.extend(quote! {
83                    .add_modifier(ratatui::style::Modifier::UNDERLINED)
84                });
85            }
86
87            if style_values.slow_blink.is_some() {
88                field_constructor.extend(quote! {
89                    .add_modifier(ratatui::style::Modifier::SLOW_BLINK)
90                });
91            }
92
93            if style_values.rapid_blink.is_some() {
94                field_constructor.extend(quote! {
95                    .add_modifier(ratatui::style::Modifier::RAPID_BLINK)
96                });
97            }
98
99            if style_values.reversed.is_some() {
100                field_constructor.extend(quote! {
101                    .add_modifier(ratatui::style::Modifier::REVERSED)
102                });
103            }
104
105            if style_values.hidden.is_some() {
106                field_constructor.extend(quote! {
107                    .add_modifier(ratatui::style::Modifier::HIDDEN)
108                });
109            }
110
111            if style_values.crossed_out.is_some() {
112                field_constructor.extend(quote! {
113                    .add_modifier(ratatui::style::Modifier::CROSSED_OUT)
114                });
115            }
116
117            field_constructors.push(field_constructor);
118            continue;
119        }
120
121        // Handle `border_type` tagged fields.
122        let attr = extract_border_type_attribute(&field.attrs);
123        if let Some(attr) = attr {
124            let border_type_value = process_border_type_attribute(attr);
125            let Some(border_type_value) = border_type_value else {
126                panic!("missing value in `border_type` on field `{:?}`", field_name);
127            };
128
129            match border_type_value {
130                BorderTypeAttribute::Variant(variant) => {
131                    let variant_str = variant.to_string();
132                    let variant_ident = match variant_str.as_str() {
133                        "Plain" | "plain" => quote! { Plain },
134                        "Rounded" | "rounded" => quote! { Rounded },
135                        "Double" | "double" => quote! { Double },
136                        "Thick" | "thick" => quote! { Thick },
137                        "QuadrantInside" | "quadrant_inside" => quote! { QuadrantInside },
138                        "QuadrantOutside" | "quadrant_outside" => quote! { QuadrantOutside },
139                        _ => panic!("unknown BorderType variant: {}", variant_str),
140                    };
141                    field_constructor.extend(quote! {
142                        #field_name: ratatui::widgets::BorderType::#variant_ident
143                    });
144                }
145                BorderTypeAttribute::Value(value) => {
146                    field_constructor.extend(quote! {
147                        #field_name: context.#value.clone()
148                    });
149                }
150            }
151
152            field_constructors.push(field_constructor);
153            continue;
154        }
155
156        // Handle `builder` tagged fields.
157        let attr = extract_builder_attribute(&field.attrs);
158        if let Some(attr) = attr {
159            let value = process_builder_field_attribute(attr);
160            let Some(value) = value else {
161                panic!("missing value in `builder` on field `{:?}`", field_name);
162            };
163
164            match value {
165                BuilderFieldAttribute::Value(value) => {
166                    field_constructor.extend(quote! {
167                            #field_name: context.#value.clone()
168                    });
169                }
170                BuilderFieldAttribute::Default => {
171                    field_constructor.extend(quote! {
172                        #field_name: <#field_type>::default()
173                    });
174                }
175            }
176
177            field_constructors.push(field_constructor);
178            continue;
179        }
180
181        // Handle untagged fields.
182        field_constructor.extend(quote! {
183                #field_name: #field_type::build(context)
184        });
185
186        field_constructors.push(field_constructor);
187    }
188
189    let implementation = quote! {
190        impl tui_theme_builder::ThemeBuilder for #struct_name {
191            type Context = #context_name;
192            fn build(context: &#context_name) -> Self {
193                Self {
194                    #(#field_constructors),*
195                }
196            }
197        }
198    };
199
200    TokenStream::from(implementation)
201}
202
203/// A helper method to extract the `builder` attribute in a list of attributes.
204fn extract_builder_attribute(attrs: &[Attribute]) -> Option<&Attribute> {
205    attrs.iter().find(|attr| attr.path().is_ident("builder"))
206}
207
208/// A helper method that processes a field with builder annotation.
209fn process_builder_field_attribute(attr: &Attribute) -> Option<BuilderFieldAttribute> {
210    let mut attribute: Option<BuilderFieldAttribute> = None;
211
212    let _ = attr.parse_nested_meta(|meta| {
213        if meta.path.is_ident("value") {
214            let value = meta.value()?;
215            let value = extract_metadata_stream(value)?;
216            if value.to_string() == "default" {
217                attribute = Some(BuilderFieldAttribute::Default);
218            } else {
219                attribute = Some(BuilderFieldAttribute::Value(value));
220            }
221            Ok(())
222        } else {
223            Err(meta.error("unsupported attribute"))
224        }
225    });
226
227    attribute
228}
229
230enum BuilderFieldAttribute {
231    Value(TokenStream2),
232    Default,
233}
234
235/// Helper to that process the builder attribute of a struct and returns the
236/// ident of the context type.
237fn process_builder_struct_attribute(attr: &Attribute) -> Option<Ident> {
238    let mut context: Option<Ident> = None;
239
240    let _ = attr.parse_nested_meta(|meta| {
241        if meta.path.is_ident("context") {
242            let value = meta.value()?;
243            let ident: syn::Ident = value.parse()?;
244            context = Some(ident);
245            Ok(())
246        } else {
247            Err(meta.error("unsupported attribute"))
248        }
249    });
250
251    context
252}
253
254/// A helper method to extract the `style` attribute in a list of attributes.
255fn extract_style_attribute(attrs: &[Attribute]) -> Option<&Attribute> {
256    attrs.iter().find(|attr| attr.path().is_ident("style"))
257}
258
259/// A helper method that processes a field with style annotation.
260fn process_style_attribute(attr: &Attribute) -> StyleValues {
261    let mut foreground: Option<TokenStream2> = None;
262    let mut background: Option<TokenStream2> = None;
263    let mut bold: Option<bool> = None;
264    let mut dim: Option<bool> = None;
265    let mut italic: Option<bool> = None;
266    let mut underlined: Option<bool> = None;
267    let mut slow_blink: Option<bool> = None;
268    let mut rapid_blink: Option<bool> = None;
269    let mut reversed: Option<bool> = None;
270    let mut hidden: Option<bool> = None;
271    let mut crossed_out: Option<bool> = None;
272
273    let _ = attr.parse_nested_meta(|meta| {
274        if let Some(ident) = meta.path.get_ident() {
275            match ident.to_string().as_str() {
276                "bold" => bold = Some(true),
277                "dim" => dim = Some(true),
278                "italic" => italic = Some(true),
279                "underlined" => underlined = Some(true),
280                "slow_blink" => slow_blink = Some(true),
281                "rapid_blink" => rapid_blink = Some(true),
282                "reversed" => reversed = Some(true),
283                "hidden" => hidden = Some(true),
284                "crossed_out" => crossed_out = Some(true),
285                "fg" | "foreground" => {
286                    let value = meta.value()?;
287                    let ident = extract_metadata_stream(value).unwrap();
288                    foreground = Some(ident);
289                }
290                "bg" | "background" => {
291                    let value = meta.value()?;
292                    let ident = extract_metadata_stream(value)?;
293                    background = Some(ident);
294                }
295                _ => {}
296            }
297        }
298
299        Ok(())
300    });
301
302    StyleValues {
303        foreground,
304        background,
305        bold,
306        dim,
307        italic,
308        underlined,
309        slow_blink,
310        rapid_blink,
311        reversed,
312        hidden,
313        crossed_out,
314    }
315}
316
317struct StyleValues {
318    foreground: Option<TokenStream2>,
319    background: Option<TokenStream2>,
320    bold: Option<bool>,
321    dim: Option<bool>,
322    italic: Option<bool>,
323    underlined: Option<bool>,
324    slow_blink: Option<bool>,
325    rapid_blink: Option<bool>,
326    reversed: Option<bool>,
327    hidden: Option<bool>,
328    crossed_out: Option<bool>,
329}
330
331/// A helper method to extract the `border_type` attribute in a list of attributes.
332fn extract_border_type_attribute(attrs: &[Attribute]) -> Option<&Attribute> {
333    attrs
334        .iter()
335        .find(|attr| attr.path().is_ident("border_type"))
336}
337
338/// A helper method that processes a field with border_type annotation.
339fn process_border_type_attribute(attr: &Attribute) -> Option<BorderTypeAttribute> {
340    let mut attribute: Option<BorderTypeAttribute> = None;
341
342    let _ = attr.parse_nested_meta(|meta| {
343        if meta.path.is_ident("value") {
344            // Handle #[border_type(value = context.field)]
345            let value = meta.value()?;
346            let value = extract_metadata_stream(value)?;
347            attribute = Some(BorderTypeAttribute::Value(value));
348            Ok(())
349        } else if let Some(ident) = meta.path.get_ident() {
350            // Handle #[border_type(Rounded)] or #[border_type(Plain)]
351            attribute = Some(BorderTypeAttribute::Variant(ident.clone()));
352            Ok(())
353        } else {
354            Err(meta.error("unsupported border_type attribute"))
355        }
356    });
357
358    attribute
359}
360
361enum BorderTypeAttribute {
362    /// A direct variant like `Rounded`, `Plain`, etc.
363    Variant(Ident),
364    /// A reference to a context field like `theme.border_type`
365    Value(TokenStream2),
366}
367
368/// A helper method that parses a `ParseStream` to a `TokenStream`. It is necessary
369/// to handle nested fields such as `#[builder(value=footer.hide)]`
370fn extract_metadata_stream(input: ParseStream) -> Result<TokenStream2, syn::Error> {
371    let mut tokens = TokenStream2::new();
372    while !input.is_empty() {
373        if input.peek(Ident) {
374            let ident: Ident = input.parse()?;
375            tokens.extend(Some(TokenTree::Ident(ident)));
376        } else if input.peek(syn::Token![.]) {
377            let _dot: syn::Token![.] = input.parse()?;
378            tokens.extend(Some(TokenTree::Punct(Punct::new('.', Spacing::Alone))));
379        } else if input.peek(syn::Token![,]) {
380            break;
381        } else {
382            return Err(input.error(format!(
383                "expected an identifier or a dot, but got {input:?}",
384            )));
385        }
386    }
387
388    Ok(tokens)
389}