use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::{
    Data, DeriveInput, Fields, Meta, Token, ext::IdentExt, parse_macro_input,
    punctuated::Punctuated, spanned::Spanned,
};

5#[proc_macro_derive(FromRow, attributes(from_row))]
33pub fn derive_from_row(input: TokenStream) -> TokenStream {
34 let input = parse_macro_input!(input as DeriveInput);
35
36 let name = &input.ident;
37 let generics = &input.generics;
38 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
39
40 let strict = input.attrs.iter().any(|attr| {
42 if !attr.path().is_ident("from_row") {
43 return false;
44 }
45 match &attr.meta {
46 Meta::List(list) => list.tokens.to_string().contains("strict"),
47 _ => false,
48 }
49 });
50
51 let fields = match &input.data {
52 Data::Struct(data) => match &data.fields {
53 Fields::Named(fields) => &fields.named,
54 _ => panic!("FromRow only supports structs with named fields"),
55 },
56 _ => panic!("FromRow only supports structs"),
57 };
58
59 let field_names: Vec<_> = fields
60 .iter()
61 .map(|f| f.ident.as_ref().expect("named fields always have idents"))
62 .collect();
63 let field_types: Vec<_> = fields.iter().map(|f| &f.ty).collect();
64 let field_name_strs: Vec<_> = field_names.iter().map(|n| n.to_string()).collect();
65
66 let uninit_decls = field_names
68 .iter()
69 .zip(field_types.iter())
70 .map(|(name, ty)| {
71 quote! {
72 let mut #name: ::core::mem::MaybeUninit<#ty> = ::core::mem::MaybeUninit::uninit();
73 }
74 });
75
76 let set_flag_names: Vec<_> = field_names
78 .iter()
79 .map(|n| syn::Ident::new(&format!("{}_set", n), n.span()))
80 .collect();
81
82 let set_flag_decls = set_flag_names.iter().map(|flag| {
83 quote! { let mut #flag = false; }
84 });
85
86 let match_arms_text = field_names.iter().zip(field_types.iter()).zip(set_flag_names.iter()).zip(field_name_strs.iter()).map(|(((name, ty), flag), name_str)| {
88 quote! {
89 #name_str => {
90 let __val: #ty = match __value {
91 None => ::zero_postgres::conversion::FromWireValue::from_null()?,
92 Some(__bytes) => ::zero_postgres::conversion::FromWireValue::from_text(__field.type_oid(), __bytes)?,
93 };
94 #name.write(__val);
95 #flag = true;
96 }
97 }
98 });
99
100 let match_arms_binary = field_names.iter().zip(field_types.iter()).zip(set_flag_names.iter()).zip(field_name_strs.iter()).map(|(((name, ty), flag), name_str)| {
102 quote! {
103 #name_str => {
104 let __val: #ty = match __value {
105 None => ::zero_postgres::conversion::FromWireValue::from_null()?,
106 Some(__bytes) => ::zero_postgres::conversion::FromWireValue::from_binary(__field.type_oid(), __bytes)?,
107 };
108 #name.write(__val);
109 #flag = true;
110 }
111 }
112 });
113
114 let fallback_arm = if strict {
116 quote! {
117 __unknown => {
118 return Err(::zero_postgres::Error::Decode(format!("unknown column: {}", __unknown)));
119 }
120 }
121 } else {
122 quote! {
123 _ => {
124 }
126 }
127 };
128
129 let init_checks = field_names
131 .iter()
132 .zip(set_flag_names.iter())
133 .zip(field_name_strs.iter())
134 .map(|((_name, flag), name_str)| {
135 quote! {
136 if !#flag {
137 return Err(::zero_postgres::Error::Decode(format!("missing column: {}", #name_str)));
138 }
139 }
140 });
141
142 let field_inits = field_names.iter().map(|name| {
144 quote! {
145 #name: unsafe { #name.assume_init() }
146 }
147 });
148
149 let uninit_decls_text = uninit_decls.clone();
151 let set_flag_decls_text = set_flag_decls.clone();
152 let init_checks_text = init_checks.clone();
153 let field_inits_text = field_inits.clone();
154
155 let uninit_decls_binary = field_names
157 .iter()
158 .zip(field_types.iter())
159 .map(|(name, ty)| {
160 quote! {
161 let mut #name: ::core::mem::MaybeUninit<#ty> = ::core::mem::MaybeUninit::uninit();
162 }
163 });
164
165 let set_flag_decls_binary = set_flag_names.iter().map(|flag| {
166 quote! { let mut #flag = false; }
167 });
168
169 let init_checks_binary = field_names
170 .iter()
171 .zip(set_flag_names.iter())
172 .zip(field_name_strs.iter())
173 .map(|((_name, flag), name_str)| {
174 quote! {
175 if !#flag {
176 return Err(::zero_postgres::Error::Decode(format!("missing column: {}", #name_str)));
177 }
178 }
179 });
180
181 let field_inits_binary = field_names.iter().map(|name| {
182 quote! {
183 #name: unsafe { #name.assume_init() }
184 }
185 });
186
187 let expanded = quote! {
188 impl #impl_generics ::zero_postgres::conversion::FromRow<'_> for #name #ty_generics #where_clause {
189 fn from_row_text(
190 __cols: &[::zero_postgres::protocol::backend::query::FieldDescription],
191 __row: ::zero_postgres::protocol::backend::query::DataRow<'_>,
192 ) -> ::zero_postgres::Result<Self> {
193 #(#uninit_decls_text)*
194 #(#set_flag_decls_text)*
195
196 let mut __values = __row.iter();
197
198 for __field in __cols.iter() {
199 let __value = __values.next().flatten();
200 let __col_name = __field.name;
201 match __col_name {
202 #(#match_arms_text)*
203 #fallback_arm
204 }
205 }
206
207 #(#init_checks_text)*
208
209 Ok(Self {
210 #(#field_inits_text),*
211 })
212 }
213
214 fn from_row_binary(
215 __cols: &[::zero_postgres::protocol::backend::query::FieldDescription],
216 __row: ::zero_postgres::protocol::backend::query::DataRow<'_>,
217 ) -> ::zero_postgres::Result<Self> {
218 #(#uninit_decls_binary)*
219 #(#set_flag_decls_binary)*
220
221 let mut __values = __row.iter();
222
223 for __field in __cols.iter() {
224 let __value = __values.next().flatten();
225 let __col_name = __field.name;
226 match __col_name {
227 #(#match_arms_binary)*
228 #fallback_arm
229 }
230 }
231
232 #(#init_checks_binary)*
233
234 Ok(Self {
235 #(#field_inits_binary),*
236 })
237 }
238 }
239 };
240
241 TokenStream::from(expanded)
242}
243
244#[proc_macro_derive(RefFromRow)]
275pub fn derive_ref_from_row(input: TokenStream) -> TokenStream {
276 let input = parse_macro_input!(input as DeriveInput);
277
278 let name = &input.ident;
279
280 let has_repr_c_packed = input.attrs.iter().any(|attr| {
282 if !attr.path().is_ident("repr") {
283 return false;
284 }
285 let tokens = match &attr.meta {
286 Meta::List(list) => list.tokens.to_string(),
287 _ => return false,
288 };
289 tokens.contains("C") && tokens.contains("packed")
290 });
291
292 if !has_repr_c_packed {
293 return syn::Error::new(
294 input.ident.span(),
295 "RefFromRow requires #[repr(C, packed)] on the struct",
296 )
297 .to_compile_error()
298 .into();
299 }
300
301 let fields = match &input.data {
302 Data::Struct(data) => match &data.fields {
303 Fields::Named(fields) => &fields.named,
304 _ => {
305 return syn::Error::new(
306 input.ident.span(),
307 "RefFromRow only supports structs with named fields",
308 )
309 .to_compile_error()
310 .into();
311 }
312 },
313 _ => {
314 return syn::Error::new(input.ident.span(), "RefFromRow only supports structs")
315 .to_compile_error()
316 .into();
317 }
318 };
319
320 let field_types: Vec<_> = fields.iter().map(|f| &f.ty).collect();
321
322 let wire_size_checks = field_types.iter().map(|ty| {
324 quote! {
325 const _: () = {
326 fn __assert_fixed_wire_size<T: ::zero_postgres::conversion::ref_row::FixedWireSize>() {}
328 fn __check() { __assert_fixed_wire_size::<#ty>(); }
329 };
330 }
331 });
332
333 let wire_size_sum = field_types.iter().map(|ty| {
335 quote! { <#ty as ::zero_postgres::conversion::ref_row::FixedWireSize>::WIRE_SIZE }
336 });
337
338 let expanded = quote! {
339 #(#wire_size_checks)*
341
342 unsafe impl ::zerocopy::KnownLayout for #name {}
344 unsafe impl ::zerocopy::Immutable for #name {}
345 unsafe impl ::zerocopy::FromBytes for #name {}
346
347 impl<'a> ::zero_postgres::conversion::ref_row::RefFromRow<'a> for #name {
348 fn ref_from_row_binary(
349 _cols: &[::zero_postgres::protocol::backend::query::FieldDescription],
350 row: ::zero_postgres::protocol::backend::query::DataRow<'a>,
351 ) -> ::zero_postgres::Result<&'a Self> {
352 const EXPECTED_SIZE: usize = 0 #(+ #wire_size_sum)*;
354
355 let data = row.raw_data();
357
358 if data.len() < EXPECTED_SIZE {
359 return Err(::zero_postgres::Error::Decode(
360 format!(
361 "Row data too small: expected {} bytes, got {}",
362 EXPECTED_SIZE,
363 data.len()
364 )
365 ));
366 }
367
368 ::zerocopy::FromBytes::ref_from_bytes(&data[..EXPECTED_SIZE])
369 .map_err(|e| ::zero_postgres::Error::Decode(
370 format!("RefFromRow zerocopy error: {:?}", e)
371 ))
372 }
373 }
374 };
375
376 TokenStream::from(expanded)
377}