scyllax_macros_core/
entity.rs

1use darling::{ast, util, FromDeriveInput, FromField};
2use proc_macro2::TokenStream;
3use quote::{quote, ToTokens};
4use syn::{spanned::Spanned, DeriveInput, Ident, Type};
5
6#[derive(Debug, PartialEq, FromField)]
7#[darling(attributes(entity), and_then = EntityDeriveColumn::set_name)]
8pub struct EntityDeriveColumn {
9    pub ident: Option<Ident>,
10    pub ty: Type,
11    /// this is required to be set as optional and default
12    /// because we set it manually in [`EntityDeriveColumn::set_name`]. It will NEVER be None.
13    #[darling(default)]
14    pub name: Option<String>,
15
16    #[darling(default)]
17    pub counter: bool,
18    #[darling(default)]
19    pub primary_key: bool,
20    #[darling(default)]
21    pub rename: Option<String>,
22}
23
24impl EntityDeriveColumn {
25    /// Sets the name of the column, considering the rename attribute.
26    fn set_name(self) -> darling::Result<Self> {
27        let Self {
28            ident,
29            ty,
30            counter,
31            primary_key,
32            rename,
33            ..
34        } = self;
35
36        let name = rename
37            .clone()
38            .or_else(|| ident.as_ref().map(|i| i.to_string()))
39            .map(|i| format!(r##""{i}""##));
40
41        Ok(Self {
42            ident,
43            ty,
44            name,
45
46            counter,
47            primary_key,
48            rename,
49        })
50    }
51}
52
53#[derive(Debug, PartialEq, FromDeriveInput)]
54#[darling(attributes(entity), supports(struct_named))]
55pub struct EntityDerive {
56    pub ident: Ident,
57    pub data: ast::Data<util::Ignored, EntityDeriveColumn>,
58}
59
60impl ToTokens for EntityDerive {
61    fn to_tokens(&self, tokens: &mut TokenStream) {
62        let EntityDerive {
63            ref ident,
64            ref data,
65        } = *self;
66
67        let fields = data
68            .as_ref()
69            .take_struct()
70            .expect("Should never be enum")
71            .fields;
72
73        // validate counters
74        for field in fields.iter().filter(|f| f.counter) {
75            if let syn::Type::Path(path) = &field.ty {
76                if let Some(ident) = path.path.get_ident() {
77                    if ident != "scylla::frame::value::Counter" {
78                        tokens.extend(
79                            syn::Error::new(
80                                field.ident.span(),
81                                "Counter fields must be of type `scylla::frame::value::Counter`",
82                            )
83                            .to_compile_error(),
84                        );
85
86                        return;
87                    }
88                }
89            } else {
90                tokens.extend(
91                    syn::Error::new(
92                        field.ident.span(),
93                        "Counter fields must be of type `scylla::frame::value::Counter`",
94                    )
95                    .to_compile_error(),
96                );
97
98                return;
99            }
100        }
101
102        let keys: Vec<TokenStream> = fields
103            .iter()
104            .map(|f| {
105                let name = &f.name;
106                quote!(#name.to_string())
107            })
108            .collect();
109
110        let primary_keys: Vec<TokenStream> = fields
111            .iter()
112            .filter(|f| f.primary_key)
113            .map(|f| {
114                let name = &f.name;
115                quote!(#name.to_string())
116            })
117            .collect();
118
119        let spat = quote! {
120            impl scyllax::prelude::EntityExt<#ident> for #ident {
121                fn keys() -> Vec<String> {
122                    vec![#(#keys),*]
123                }
124
125                fn pks() -> Vec<String> {
126                    vec![#(#primary_keys),*]
127                }
128            }
129        };
130
131        tokens.extend(spat);
132    }
133}
134
135/// Attribute expand
136/// Just adds the dervie macro to the struct.
137pub fn expand(input: TokenStream) -> TokenStream {
138    let input: DeriveInput = match syn::parse2(input.clone()) {
139        Ok(it) => it,
140        Err(e) => return e.to_compile_error(),
141    };
142
143    match EntityDerive::from_derive_input(&input) {
144        Ok(e) => e.into_token_stream(),
145        Err(e) => e.write_errors(),
146    }
147}
148
149/// Expands the shorthand attribute
150pub fn expand_attr(_args: TokenStream, input: TokenStream) -> TokenStream {
151    quote! {
152        #[derive(
153            Clone,
154            Debug,
155            PartialEq,
156            scylla::SerializeRow,
157            scylla_reexports::FromRow,
158            scylla_reexports::ValueList,
159            scyllax::prelude::Entity
160        )]
161        #input
162    }
163}