Skip to main content

zero_postgres_derive/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{Data, DeriveInput, Fields, Meta, parse_macro_input, spanned::Spanned};
4
5/// Derive macro for `FromRow` trait.
6///
7/// Generates an implementation that matches column names to struct fields.
8///
9/// # Example
10///
11/// ```ignore
12/// #[derive(FromRow)]
13/// struct User {
14///     name: String,
15///     age: i32,
16/// }
17/// ```
18///
19/// # Strict Mode
20///
21/// By default, unknown columns are silently skipped. Use `#[from_row(strict)]`
22/// to error on unknown columns:
23///
24/// ```ignore
25/// #[derive(FromRow)]
26/// #[from_row(strict)]
27/// struct User {
28///     name: String,
29///     age: i32,
30/// }
31/// ```
32#[proc_macro_derive(FromRow, attributes(from_row))]
33pub fn derive_from_row(input: TokenStream) -> TokenStream {
34    let input = parse_macro_input!(input as DeriveInput);
35
36    let name = &input.ident;
37    let generics = &input.generics;
38    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
39
40    // Check for #[from_row(strict)]
41    let strict = input.attrs.iter().any(|attr| {
42        if !attr.path().is_ident("from_row") {
43            return false;
44        }
45        match &attr.meta {
46            Meta::List(list) => list.tokens.to_string().contains("strict"),
47            _ => false,
48        }
49    });
50
51    let fields = match &input.data {
52        Data::Struct(data) => match &data.fields {
53            Fields::Named(fields) => &fields.named,
54            _ => panic!("FromRow only supports structs with named fields"),
55        },
56        _ => panic!("FromRow only supports structs"),
57    };
58
59    let field_names: Vec<_> = fields
60        .iter()
61        .map(|f| f.ident.as_ref().expect("named fields always have idents"))
62        .collect();
63    let field_types: Vec<_> = fields.iter().map(|f| &f.ty).collect();
64    let field_name_strs: Vec<_> = field_names.iter().map(|n| n.to_string()).collect();
65
66    // Generate MaybeUninit declarations
67    let uninit_decls = field_names
68        .iter()
69        .zip(field_types.iter())
70        .map(|(name, ty)| {
71            quote! {
72                let mut #name: ::core::mem::MaybeUninit<#ty> = ::core::mem::MaybeUninit::uninit();
73            }
74        });
75
76    // Generate set flags
77    let set_flag_names: Vec<_> = field_names
78        .iter()
79        .map(|n| syn::Ident::new(&format!("{}_set", n), n.span()))
80        .collect();
81
82    let set_flag_decls = set_flag_names.iter().map(|flag| {
83        quote! { let mut #flag = false; }
84    });
85
86    // Generate match arms for text decoding
87    let match_arms_text = field_names.iter().zip(field_types.iter()).zip(set_flag_names.iter()).zip(field_name_strs.iter()).map(|(((name, ty), flag), name_str)| {
88        quote! {
89            #name_str => {
90                let __val: #ty = match __value {
91                    None => ::zero_postgres::conversion::FromWireValue::from_null()?,
92                    Some(__bytes) => ::zero_postgres::conversion::FromWireValue::from_text(__field.type_oid(), __bytes)?,
93                };
94                #name.write(__val);
95                #flag = true;
96            }
97        }
98    });
99
100    // Generate match arms for binary decoding
101    let match_arms_binary = field_names.iter().zip(field_types.iter()).zip(set_flag_names.iter()).zip(field_name_strs.iter()).map(|(((name, ty), flag), name_str)| {
102        quote! {
103            #name_str => {
104                let __val: #ty = match __value {
105                    None => ::zero_postgres::conversion::FromWireValue::from_null()?,
106                    Some(__bytes) => ::zero_postgres::conversion::FromWireValue::from_binary(__field.type_oid(), __bytes)?,
107                };
108                #name.write(__val);
109                #flag = true;
110            }
111        }
112    });
113
114    // Generate fallback arm based on strict mode
115    let fallback_arm = if strict {
116        quote! {
117            __unknown => {
118                return Err(::zero_postgres::Error::Decode(format!("unknown column: {}", __unknown)));
119            }
120        }
121    } else {
122        quote! {
123            _ => {
124                // Skip unknown column
125            }
126        }
127    };
128
129    // Generate initialization checks
130    let init_checks = field_names
131        .iter()
132        .zip(set_flag_names.iter())
133        .zip(field_name_strs.iter())
134        .map(|((_name, flag), name_str)| {
135            quote! {
136                if !#flag {
137                    return Err(::zero_postgres::Error::Decode(format!("missing column: {}", #name_str)));
138                }
139            }
140        });
141
142    // Generate struct construction
143    let field_inits = field_names.iter().map(|name| {
144        quote! {
145            #name: unsafe { #name.assume_init() }
146        }
147    });
148
149    // Clone iterators for text implementation
150    let uninit_decls_text = uninit_decls.clone();
151    let set_flag_decls_text = set_flag_decls.clone();
152    let init_checks_text = init_checks.clone();
153    let field_inits_text = field_inits.clone();
154
155    // Clone for binary implementation
156    let uninit_decls_binary = field_names
157        .iter()
158        .zip(field_types.iter())
159        .map(|(name, ty)| {
160            quote! {
161                let mut #name: ::core::mem::MaybeUninit<#ty> = ::core::mem::MaybeUninit::uninit();
162            }
163        });
164
165    let set_flag_decls_binary = set_flag_names.iter().map(|flag| {
166        quote! { let mut #flag = false; }
167    });
168
169    let init_checks_binary = field_names
170        .iter()
171        .zip(set_flag_names.iter())
172        .zip(field_name_strs.iter())
173        .map(|((_name, flag), name_str)| {
174            quote! {
175                if !#flag {
176                    return Err(::zero_postgres::Error::Decode(format!("missing column: {}", #name_str)));
177                }
178            }
179        });
180
181    let field_inits_binary = field_names.iter().map(|name| {
182        quote! {
183            #name: unsafe { #name.assume_init() }
184        }
185    });
186
187    let expanded = quote! {
188        impl #impl_generics ::zero_postgres::conversion::FromRow<'_> for #name #ty_generics #where_clause {
189            fn from_row_text(
190                __cols: &[::zero_postgres::protocol::backend::query::FieldDescription],
191                __row: ::zero_postgres::protocol::backend::query::DataRow<'_>,
192            ) -> ::zero_postgres::Result<Self> {
193                #(#uninit_decls_text)*
194                #(#set_flag_decls_text)*
195
196                let mut __values = __row.iter();
197
198                for __field in __cols.iter() {
199                    let __value = __values.next().flatten();
200                    let __col_name = __field.name;
201                    match __col_name {
202                        #(#match_arms_text)*
203                        #fallback_arm
204                    }
205                }
206
207                #(#init_checks_text)*
208
209                Ok(Self {
210                    #(#field_inits_text),*
211                })
212            }
213
214            fn from_row_binary(
215                __cols: &[::zero_postgres::protocol::backend::query::FieldDescription],
216                __row: ::zero_postgres::protocol::backend::query::DataRow<'_>,
217            ) -> ::zero_postgres::Result<Self> {
218                #(#uninit_decls_binary)*
219                #(#set_flag_decls_binary)*
220
221                let mut __values = __row.iter();
222
223                for __field in __cols.iter() {
224                    let __value = __values.next().flatten();
225                    let __col_name = __field.name;
226                    match __col_name {
227                        #(#match_arms_binary)*
228                        #fallback_arm
229                    }
230                }
231
232                #(#init_checks_binary)*
233
234                Ok(Self {
235                    #(#field_inits_binary),*
236                })
237            }
238        }
239    };
240
241    TokenStream::from(expanded)
242}
243
244/// Derive macro for `RefFromRow` trait - zero-copy row decoding.
245///
246/// This macro generates a zero-copy implementation that returns a reference
247/// directly into the row buffer. It also derives zerocopy traits automatically.
248///
249/// # Requirements
250///
251/// - Struct must have `#[repr(C, packed)]` attribute
252/// - All fields must be `LengthPrefixed<T>` where `T` implements `FixedWireSize`
253/// - All columns must be `NOT NULL` (no `Option<T>` support)
254/// - Only works with binary format (extended queries)
255///
256/// # PostgreSQL Wire Format
257///
258/// PostgreSQL's binary protocol includes a 4-byte length prefix before each
259/// column value. Use `LengthPrefixed<T>` to account for this in the struct layout.
260///
261/// # Example
262///
263/// ```ignore
264/// use zero_postgres::conversion::ref_row::{RefFromRow, LengthPrefixed, I64BE, I32BE};
265///
266/// #[derive(RefFromRow)]
267/// #[repr(C, packed)]
268/// struct UserStats {
269///     user_id: LengthPrefixed<I64BE>,     // 4 + 8 = 12 bytes
270///     login_count: LengthPrefixed<I32BE>, // 4 + 4 = 8 bytes
271/// }
272/// // Total wire size: 20 bytes
273/// ```
274#[proc_macro_derive(RefFromRow)]
275pub fn derive_ref_from_row(input: TokenStream) -> TokenStream {
276    let input = parse_macro_input!(input as DeriveInput);
277
278    let name = &input.ident;
279
280    // Check for #[repr(C, packed)]
281    let has_repr_c_packed = input.attrs.iter().any(|attr| {
282        if !attr.path().is_ident("repr") {
283            return false;
284        }
285        let tokens = match &attr.meta {
286            Meta::List(list) => list.tokens.to_string(),
287            _ => return false,
288        };
289        tokens.contains("C") && tokens.contains("packed")
290    });
291
292    if !has_repr_c_packed {
293        return syn::Error::new(
294            input.ident.span(),
295            "RefFromRow requires #[repr(C, packed)] on the struct",
296        )
297        .to_compile_error()
298        .into();
299    }
300
301    let fields = match &input.data {
302        Data::Struct(data) => match &data.fields {
303            Fields::Named(fields) => &fields.named,
304            _ => {
305                return syn::Error::new(
306                    input.ident.span(),
307                    "RefFromRow only supports structs with named fields",
308                )
309                .to_compile_error()
310                .into();
311            }
312        },
313        _ => {
314            return syn::Error::new(input.ident.span(), "RefFromRow only supports structs")
315                .to_compile_error()
316                .into();
317        }
318    };
319
320    let field_types: Vec<_> = fields.iter().map(|f| &f.ty).collect();
321
322    // Generate compile-time assertions that all fields implement FixedWireSize
323    let wire_size_checks = field_types.iter().map(|ty| {
324        quote! {
325            const _: () = {
326                // This fails to compile if the type doesn't implement FixedWireSize
327                fn __assert_fixed_wire_size<T: ::zero_postgres::conversion::ref_row::FixedWireSize>() {}
328                fn __check() { __assert_fixed_wire_size::<#ty>(); }
329            };
330        }
331    });
332
333    // Calculate total wire size at compile time
334    let wire_size_sum = field_types.iter().map(|ty| {
335        quote! { <#ty as ::zero_postgres::conversion::ref_row::FixedWireSize>::WIRE_SIZE }
336    });
337
338    let expanded = quote! {
339        // Compile-time checks that all fields implement FixedWireSize
340        #(#wire_size_checks)*
341
342        // Derive zerocopy traits for zero-copy access
343        unsafe impl ::zerocopy::KnownLayout for #name {}
344        unsafe impl ::zerocopy::Immutable for #name {}
345        unsafe impl ::zerocopy::FromBytes for #name {}
346
347        impl<'a> ::zero_postgres::conversion::ref_row::RefFromRow<'a> for #name {
348            fn ref_from_row_binary(
349                _cols: &[::zero_postgres::protocol::backend::query::FieldDescription],
350                row: ::zero_postgres::protocol::backend::query::DataRow<'a>,
351            ) -> ::zero_postgres::Result<&'a Self> {
352                // Expected size (includes length prefixes via LengthPrefixed<T>)
353                const EXPECTED_SIZE: usize = 0 #(+ #wire_size_sum)*;
354
355                // Get raw data including length prefixes
356                let data = row.raw_data();
357
358                if data.len() < EXPECTED_SIZE {
359                    return Err(::zero_postgres::Error::Decode(
360                        format!(
361                            "Row data too small: expected {} bytes, got {}",
362                            EXPECTED_SIZE,
363                            data.len()
364                        )
365                    ));
366                }
367
368                ::zerocopy::FromBytes::ref_from_bytes(&data[..EXPECTED_SIZE])
369                    .map_err(|e| ::zero_postgres::Error::Decode(
370                        format!("RefFromRow zerocopy error: {:?}", e)
371                    ))
372            }
373        }
374    };
375
376    TokenStream::from(expanded)
377}