tokio_postgres_macros/
lib.rs

1use std::cell::Cell;
2use quote2::{proc_macro2::{TokenStream, TokenTree, Literal}, quote, Quote};
3use syn::*;
4
5/// Implements `From<&Row>` trait for a struct, allowing direct conversion from a database row to the struct.
6///
7/// ## Example
8///
9/// ```rust
10/// use tokio_postgres_utils::FromRow;
11///
12/// #[derive(FromRow)]
13/// struct User {
14///     id: i32,
15///     name: String,
16/// }
17/// ```
18///
19/// Expand into:
20///
21/// ```
22/// impl From<&Row> for User {
23///     fn from(row: &Row) -> Self {
24///         Self {
25///             id: row.get("id"),
26///             name: row.get("name"),
27///         }
28///     }
29/// }
30/// ```
31#[proc_macro_derive(FromRow, attributes(column))]
32pub fn from_row(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
33    let input = parse_macro_input!(input as DeriveInput);
34    let name = input.ident;
35    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
36
37    let body = quote(|tokens| match input.data {
38        Data::Struct(data) => match &data.fields {
39            Fields::Named(fields) => {
40                let body = quote(|tokens| {
41                    for field in &fields.named {
42                        if let Some(name) = &field.ident {
43                            quote!(tokens, { #name: });
44                            match column_attr(&field.attrs) {
45                                ColumnAttr::Flatten => {
46                                    quote!(tokens, {
47                                        ::std::convert::TryFrom::try_from(r).unwrap(),
48                                    });
49                                }
50                                ColumnAttr::Rename(rename) => {
51                                    quote!(tokens, {
52                                        r.get(#rename),
53                                    });
54                                },
55                                ColumnAttr::None => {
56                                    let raw_str = name.to_string();
57                                    quote!(tokens, {
58                                        r.get(#raw_str),
59                                    });
60                                }
61                                ColumnAttr::Skip => {
62                                    quote!(tokens, {
63                                        ::std::default::Default::default(),
64                                    });
65                                }
66                            }
67                        }
68                    }
69                });
70                quote!(tokens, {
71                    { #body }
72                });
73            }
74            Fields::Unnamed(fields) => {
75                let body = quote(|tokens| {
76                    for (i, _) in fields.unnamed.iter().enumerate() {
77                        let idx = Index::from(i);
78                        quote!(tokens, {
79                            r.get(#idx),
80                        });
81                    }
82                });
83                quote!(tokens, {
84                    (#body)
85                });
86            }
87            Fields::Unit => {}
88        },
89        Data::Enum(_) | Data::Union(_) => unimplemented!(),
90    });
91
92    let mut tokens = TokenStream::new();
93    quote!(tokens, {
94        impl #impl_generics ::std::convert::From<&tokio_postgres::Row> for #name #ty_generics #where_clause {
95            #[inline]
96            fn from(r: &tokio_postgres::Row) -> Self {
97                Self #body
98            }
99        }
100    });
101    tokens.into()
102}
103
104/// Implements the `TryFrom<&Row>` trait for a struct
105#[proc_macro_derive(TryFromRow, attributes(column))]
106pub fn try_from_row(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
107    let input = parse_macro_input!(input as DeriveInput);
108    let name = input.ident;
109    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
110
111    let has_attr = Cell::new(false);
112
113    let body = quote(|tokens| match input.data {
114        Data::Struct(data) => match &data.fields {
115            Fields::Named(fields) => {
116                let body = quote(|tokens| {
117                    for field in &fields.named {
118                        if let Some(name) = &field.ident {
119                            quote!(tokens, { #name: });
120                            match column_attr(&field.attrs) {
121                                ColumnAttr::Flatten => {
122                                    has_attr.set(true);
123                                    quote!(tokens, {
124                                        ::std::convert::TryFrom::try_from(r)?,
125                                    });
126                                }
127                                ColumnAttr::Rename(rename) => {
128                                    quote!(tokens, {
129                                        r.try_get(#rename)?,
130                                    });
131                                },
132                                ColumnAttr::None => {
133                                    let raw_str = name.to_string();
134                                    quote!(tokens, {
135                                        r.try_get(#raw_str)?,
136                                    });
137                                }
138                                ColumnAttr::Skip => {
139                                    quote!(tokens, {
140                                        ::std::default::Default::default(),
141                                    });
142                                }
143                            }
144                        }
145                    }
146                });
147                quote!(tokens, {
148                    { #body }
149                });
150            }
151            Fields::Unnamed(fields) => {
152                let body = quote(|tokens| {
153                    for (i, _) in fields.unnamed.iter().enumerate() {
154                        let idx = Index::from(i);
155                        quote!(tokens, {
156                            r.try_get(#idx)?,
157                        });
158                    }
159                });
160                quote!(tokens, {
161                    (#body)
162                });
163            }
164            Fields::Unit => {}
165        },
166        Data::Enum(_) | Data::Union(_) => unimplemented!(),
167    });
168
169    let err_ty = quote(|t| {
170        if has_attr.get() {
171            quote!(t, { ::std::boxed::Box<dyn ::std::error::Error + ::std::marker::Send + ::std::marker::Sync> });
172        } else {
173            quote!(t, { tokio_postgres::Error });
174        }
175    });
176
177    let mut tokens = TokenStream::new();
178    quote!(tokens, {
179        impl #impl_generics ::std::convert::TryFrom<&tokio_postgres::Row> for #name #ty_generics #where_clause {
180            #[inline]
181            fn try_from(r: &tokio_postgres::Row) -> ::std::result::Result<Self, Self::Error> {
182                Ok(Self #body)
183            }
184            type Error = #err_ty;
185        }
186    });
187    tokens.into()
188}
189
190
191
192enum ColumnAttr {
193    Skip,
194    Flatten,
195    None,
196    Rename(Literal),
197}
198
199fn column_attr(attrs: &[Attribute]) -> ColumnAttr {
200    attrs
201        .iter()
202        .find_map(|attr| {
203            if let Meta::List(MetaList { path, tokens, .. }) = &attr.meta {
204                if path.segments.first()?.ident == "column" {
205                    let mut tokens = tokens.clone().into_iter();
206                    match tokens.next()?.to_string().as_str() {
207                        "skip" => return Some(ColumnAttr::Skip),
208                        "flatten" => return Some(ColumnAttr::Flatten),
209                        "rename" => {
210                            if matches!(tokens.next()?, TokenTree::Punct(p) if p.as_char() == '=') {
211                                if let TokenTree::Literal(lit) = tokens.next()? {
212                                    return Some(ColumnAttr::Rename(lit));
213                                }
214                            }
215                        }
216                        _ => {}
217                    }
218                }
219            }
220            None
221        })
222        .unwrap_or(ColumnAttr::None)
223}