zero_postgres_derive/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{Data, DeriveInput, Fields, Meta, parse_macro_input};
4
5/// Derive macro for `FromRow` trait.
6///
7/// Generates an implementation that matches column names to struct fields.
8///
9/// # Example
10///
11/// ```ignore
12/// #[derive(FromRow)]
13/// struct User {
14///     name: String,
15///     age: i32,
16/// }
17/// ```
18///
19/// # Strict Mode
20///
21/// By default, unknown columns are silently skipped. Use `#[from_row(strict)]`
22/// to error on unknown columns:
23///
24/// ```ignore
25/// #[derive(FromRow)]
26/// #[from_row(strict)]
27/// struct User {
28///     name: String,
29///     age: i32,
30/// }
31/// ```
32#[proc_macro_derive(FromRow, attributes(from_row))]
33pub fn derive_from_row(input: TokenStream) -> TokenStream {
34    let input = parse_macro_input!(input as DeriveInput);
35
36    let name = &input.ident;
37    let generics = &input.generics;
38    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
39
40    // Check for #[from_row(strict)]
41    let strict = input.attrs.iter().any(|attr| {
42        if !attr.path().is_ident("from_row") {
43            return false;
44        }
45        match &attr.meta {
46            Meta::List(list) => list.tokens.to_string().contains("strict"),
47            _ => false,
48        }
49    });
50
51    let fields = match &input.data {
52        Data::Struct(data) => match &data.fields {
53            Fields::Named(fields) => &fields.named,
54            _ => panic!("FromRow only supports structs with named fields"),
55        },
56        _ => panic!("FromRow only supports structs"),
57    };
58
59    let field_names: Vec<_> = fields.iter().map(|f| f.ident.as_ref().unwrap()).collect();
60    let field_types: Vec<_> = fields.iter().map(|f| &f.ty).collect();
61    let field_name_strs: Vec<_> = field_names.iter().map(|n| n.to_string()).collect();
62
63    // Generate MaybeUninit declarations
64    let uninit_decls = field_names
65        .iter()
66        .zip(field_types.iter())
67        .map(|(name, ty)| {
68            quote! {
69                let mut #name: ::core::mem::MaybeUninit<#ty> = ::core::mem::MaybeUninit::uninit();
70            }
71        });
72
73    // Generate set flags
74    let set_flag_names: Vec<_> = field_names
75        .iter()
76        .map(|n| syn::Ident::new(&format!("{}_set", n), n.span()))
77        .collect();
78
79    let set_flag_decls = set_flag_names.iter().map(|flag| {
80        quote! { let mut #flag = false; }
81    });
82
83    // Generate match arms for text decoding
84    let match_arms_text = field_names.iter().zip(field_types.iter()).zip(set_flag_names.iter()).zip(field_name_strs.iter()).map(|(((name, ty), flag), name_str)| {
85        quote! {
86            #name_str => {
87                let __val: #ty = match __value {
88                    None => ::zero_postgres::conversion::FromWireValue::from_null()?,
89                    Some(__bytes) => ::zero_postgres::conversion::FromWireValue::from_text(__field.type_oid(), __bytes)?,
90                };
91                #name.write(__val);
92                #flag = true;
93            }
94        }
95    });
96
97    // Generate match arms for binary decoding
98    let match_arms_binary = field_names.iter().zip(field_types.iter()).zip(set_flag_names.iter()).zip(field_name_strs.iter()).map(|(((name, ty), flag), name_str)| {
99        quote! {
100            #name_str => {
101                let __val: #ty = match __value {
102                    None => ::zero_postgres::conversion::FromWireValue::from_null()?,
103                    Some(__bytes) => ::zero_postgres::conversion::FromWireValue::from_binary(__field.type_oid(), __bytes)?,
104                };
105                #name.write(__val);
106                #flag = true;
107            }
108        }
109    });
110
111    // Generate fallback arm based on strict mode
112    let fallback_arm = if strict {
113        quote! {
114            __unknown => {
115                return Err(::zero_postgres::Error::Decode(format!("unknown column: {}", __unknown)));
116            }
117        }
118    } else {
119        quote! {
120            _ => {
121                // Skip unknown column
122            }
123        }
124    };
125
126    // Generate initialization checks
127    let init_checks = field_names
128        .iter()
129        .zip(set_flag_names.iter())
130        .zip(field_name_strs.iter())
131        .map(|((_name, flag), name_str)| {
132            quote! {
133                if !#flag {
134                    return Err(::zero_postgres::Error::Decode(format!("missing column: {}", #name_str)));
135                }
136            }
137        });
138
139    // Generate struct construction
140    let field_inits = field_names.iter().map(|name| {
141        quote! {
142            #name: unsafe { #name.assume_init() }
143        }
144    });
145
146    // Clone iterators for text implementation
147    let uninit_decls_text = uninit_decls.clone();
148    let set_flag_decls_text = set_flag_decls.clone();
149    let init_checks_text = init_checks.clone();
150    let field_inits_text = field_inits.clone();
151
152    // Clone for binary implementation
153    let uninit_decls_binary = field_names
154        .iter()
155        .zip(field_types.iter())
156        .map(|(name, ty)| {
157            quote! {
158                let mut #name: ::core::mem::MaybeUninit<#ty> = ::core::mem::MaybeUninit::uninit();
159            }
160        });
161
162    let set_flag_decls_binary = set_flag_names.iter().map(|flag| {
163        quote! { let mut #flag = false; }
164    });
165
166    let init_checks_binary = field_names
167        .iter()
168        .zip(set_flag_names.iter())
169        .zip(field_name_strs.iter())
170        .map(|((_name, flag), name_str)| {
171            quote! {
172                if !#flag {
173                    return Err(::zero_postgres::Error::Decode(format!("missing column: {}", #name_str)));
174                }
175            }
176        });
177
178    let field_inits_binary = field_names.iter().map(|name| {
179        quote! {
180            #name: unsafe { #name.assume_init() }
181        }
182    });
183
184    let expanded = quote! {
185        impl #impl_generics ::zero_postgres::conversion::FromRow<'_> for #name #ty_generics #where_clause {
186            fn from_row_text(
187                __cols: &[::zero_postgres::protocol::backend::query::FieldDescription],
188                __row: ::zero_postgres::protocol::backend::query::DataRow<'_>,
189            ) -> ::zero_postgres::Result<Self> {
190                #(#uninit_decls_text)*
191                #(#set_flag_decls_text)*
192
193                let mut __values = __row.iter();
194
195                for __field in __cols.iter() {
196                    let __value = __values.next().flatten();
197                    let __col_name = __field.name;
198                    match __col_name {
199                        #(#match_arms_text)*
200                        #fallback_arm
201                    }
202                }
203
204                #(#init_checks_text)*
205
206                Ok(Self {
207                    #(#field_inits_text),*
208                })
209            }
210
211            fn from_row_binary(
212                __cols: &[::zero_postgres::protocol::backend::query::FieldDescription],
213                __row: ::zero_postgres::protocol::backend::query::DataRow<'_>,
214            ) -> ::zero_postgres::Result<Self> {
215                #(#uninit_decls_binary)*
216                #(#set_flag_decls_binary)*
217
218                let mut __values = __row.iter();
219
220                for __field in __cols.iter() {
221                    let __value = __values.next().flatten();
222                    let __col_name = __field.name;
223                    match __col_name {
224                        #(#match_arms_binary)*
225                        #fallback_arm
226                    }
227                }
228
229                #(#init_checks_binary)*
230
231                Ok(Self {
232                    #(#field_inits_binary),*
233                })
234            }
235        }
236    };
237
238    TokenStream::from(expanded)
239}