// zero_postgres_derive/lib.rs
use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::{Data, DeriveInput, Fields, Meta, parse_macro_input};

5#[proc_macro_derive(FromRow, attributes(from_row))]
33pub fn derive_from_row(input: TokenStream) -> TokenStream {
34 let input = parse_macro_input!(input as DeriveInput);
35
36 let name = &input.ident;
37 let generics = &input.generics;
38 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
39
40 let strict = input.attrs.iter().any(|attr| {
42 if !attr.path().is_ident("from_row") {
43 return false;
44 }
45 match &attr.meta {
46 Meta::List(list) => list.tokens.to_string().contains("strict"),
47 _ => false,
48 }
49 });
50
51 let fields = match &input.data {
52 Data::Struct(data) => match &data.fields {
53 Fields::Named(fields) => &fields.named,
54 _ => panic!("FromRow only supports structs with named fields"),
55 },
56 _ => panic!("FromRow only supports structs"),
57 };
58
59 let field_names: Vec<_> = fields.iter().map(|f| f.ident.as_ref().unwrap()).collect();
60 let field_types: Vec<_> = fields.iter().map(|f| &f.ty).collect();
61 let field_name_strs: Vec<_> = field_names.iter().map(|n| n.to_string()).collect();
62
63 let uninit_decls = field_names
65 .iter()
66 .zip(field_types.iter())
67 .map(|(name, ty)| {
68 quote! {
69 let mut #name: ::core::mem::MaybeUninit<#ty> = ::core::mem::MaybeUninit::uninit();
70 }
71 });
72
73 let set_flag_names: Vec<_> = field_names
75 .iter()
76 .map(|n| syn::Ident::new(&format!("{}_set", n), n.span()))
77 .collect();
78
79 let set_flag_decls = set_flag_names.iter().map(|flag| {
80 quote! { let mut #flag = false; }
81 });
82
83 let match_arms_text = field_names.iter().zip(field_types.iter()).zip(set_flag_names.iter()).zip(field_name_strs.iter()).map(|(((name, ty), flag), name_str)| {
85 quote! {
86 #name_str => {
87 let __val: #ty = match __value {
88 None => ::zero_postgres::conversion::FromWireValue::from_null()?,
89 Some(__bytes) => ::zero_postgres::conversion::FromWireValue::from_text(__field.type_oid(), __bytes)?,
90 };
91 #name.write(__val);
92 #flag = true;
93 }
94 }
95 });
96
97 let match_arms_binary = field_names.iter().zip(field_types.iter()).zip(set_flag_names.iter()).zip(field_name_strs.iter()).map(|(((name, ty), flag), name_str)| {
99 quote! {
100 #name_str => {
101 let __val: #ty = match __value {
102 None => ::zero_postgres::conversion::FromWireValue::from_null()?,
103 Some(__bytes) => ::zero_postgres::conversion::FromWireValue::from_binary(__field.type_oid(), __bytes)?,
104 };
105 #name.write(__val);
106 #flag = true;
107 }
108 }
109 });
110
111 let fallback_arm = if strict {
113 quote! {
114 __unknown => {
115 return Err(::zero_postgres::Error::Decode(format!("unknown column: {}", __unknown)));
116 }
117 }
118 } else {
119 quote! {
120 _ => {
121 }
123 }
124 };
125
126 let init_checks = field_names
128 .iter()
129 .zip(set_flag_names.iter())
130 .zip(field_name_strs.iter())
131 .map(|((_name, flag), name_str)| {
132 quote! {
133 if !#flag {
134 return Err(::zero_postgres::Error::Decode(format!("missing column: {}", #name_str)));
135 }
136 }
137 });
138
139 let field_inits = field_names.iter().map(|name| {
141 quote! {
142 #name: unsafe { #name.assume_init() }
143 }
144 });
145
146 let uninit_decls_text = uninit_decls.clone();
148 let set_flag_decls_text = set_flag_decls.clone();
149 let init_checks_text = init_checks.clone();
150 let field_inits_text = field_inits.clone();
151
152 let uninit_decls_binary = field_names
154 .iter()
155 .zip(field_types.iter())
156 .map(|(name, ty)| {
157 quote! {
158 let mut #name: ::core::mem::MaybeUninit<#ty> = ::core::mem::MaybeUninit::uninit();
159 }
160 });
161
162 let set_flag_decls_binary = set_flag_names.iter().map(|flag| {
163 quote! { let mut #flag = false; }
164 });
165
166 let init_checks_binary = field_names
167 .iter()
168 .zip(set_flag_names.iter())
169 .zip(field_name_strs.iter())
170 .map(|((_name, flag), name_str)| {
171 quote! {
172 if !#flag {
173 return Err(::zero_postgres::Error::Decode(format!("missing column: {}", #name_str)));
174 }
175 }
176 });
177
178 let field_inits_binary = field_names.iter().map(|name| {
179 quote! {
180 #name: unsafe { #name.assume_init() }
181 }
182 });
183
184 let expanded = quote! {
185 impl #impl_generics ::zero_postgres::conversion::FromRow<'_> for #name #ty_generics #where_clause {
186 fn from_row_text(
187 __cols: &[::zero_postgres::protocol::backend::query::FieldDescription],
188 __row: ::zero_postgres::protocol::backend::query::DataRow<'_>,
189 ) -> ::zero_postgres::Result<Self> {
190 #(#uninit_decls_text)*
191 #(#set_flag_decls_text)*
192
193 let mut __values = __row.iter();
194
195 for __field in __cols.iter() {
196 let __value = __values.next().flatten();
197 let __col_name = __field.name;
198 match __col_name {
199 #(#match_arms_text)*
200 #fallback_arm
201 }
202 }
203
204 #(#init_checks_text)*
205
206 Ok(Self {
207 #(#field_inits_text),*
208 })
209 }
210
211 fn from_row_binary(
212 __cols: &[::zero_postgres::protocol::backend::query::FieldDescription],
213 __row: ::zero_postgres::protocol::backend::query::DataRow<'_>,
214 ) -> ::zero_postgres::Result<Self> {
215 #(#uninit_decls_binary)*
216 #(#set_flag_decls_binary)*
217
218 let mut __values = __row.iter();
219
220 for __field in __cols.iter() {
221 let __value = __values.next().flatten();
222 let __col_name = __field.name;
223 match __col_name {
224 #(#match_arms_binary)*
225 #fallback_arm
226 }
227 }
228
229 #(#init_checks_binary)*
230
231 Ok(Self {
232 #(#field_inits_binary),*
233 })
234 }
235 }
236 };
237
238 TokenStream::from(expanded)
239}