// ts_sql_helper_derive/lib.rs
use quote::{quote, quote_spanned};
5use syn::{
6 Data, DeriveInput, Fields, GenericParam, Generics, Type, TypeParamBound, parse_macro_input,
7 parse_quote, spanned::Spanned,
8};
9
10#[proc_macro_derive(FromRow)]
12pub fn derive_from_row(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
13 let input = parse_macro_input!(input as DeriveInput);
15
16 let name = input.ident;
17
18 let generics = add_trait_bounds(
20 input.generics,
21 parse_quote!(ts_sql_helper_lib::postgres::types::FromSql),
22 );
23 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
24
25 let Data::Struct(data_struct) = input.data else {
26 panic!("FromRow can only be derived on a struct")
27 };
28
29 let Fields::Named(fields) = data_struct.fields else {
30 panic!("FromRow can only be derived on a struct with named fields")
31 };
32
33 let each_field_from_row = fields.named.iter().filter_map(|f| {
34 let name = f.ident.as_ref()?;
35
36 let field_type = &f.ty;
37
38 let name_lit = name.to_string();
39
40 Some(quote_spanned! {f.span()=>
41 let #name: #field_type = row.try_get(#name_lit).ok()?;
42 })
43 });
44
45 let struct_fields = fields.named.iter().map(|f| {
46 let name = &f.ident;
47 quote_spanned! {f.span() => #name}
48 });
49
50 let expanded = quote! {
51 impl #impl_generics ts_sql_helper_lib::FromRow for #name #ty_generics #where_clause {
53 fn from_row(row: &ts_sql_helper_lib::postgres::Row) -> Option<Self> {
54 #( #each_field_from_row )*
55
56 Some(Self {
57 #( #struct_fields ),*
58 })
59 }
60 }
61 };
62
63 proc_macro::TokenStream::from(expanded)
65}
66
67#[proc_macro_derive(FromSql)]
69pub fn derive_from_sql(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
70 let input = parse_macro_input!(input as DeriveInput);
72
73 if !matches!(input.data, Data::Enum(_)) {
74 panic!("FromSql can only be derived on an enum")
75 }
76
77 let name = input.ident;
78
79 let (repr, accepts, from_sql) = {
80 let mut repr_type = parse_quote!(&str);
81 let mut accepts = quote!(ts_sql_helper_lib::postgres::types::accepts!(TEXT));
82 let mut from_sql = quote!(ts_sql_helper_lib::postgres_protocol::types::text_from_sql(
83 raw
84 )?);
85
86 for attr in input.attrs {
87 if !attr.path().is_ident("repr") {
88 continue;
89 }
90
91 let Ok(arg) = attr.parse_args::<Type>() else {
92 continue;
93 };
94
95 if arg == parse_quote!(i8) {
96 accepts = quote!(ts_sql_helper_lib::postgres::types::accepts!(CHAR));
97 from_sql = quote!(ts_sql_helper_lib::postgres_protocol::types::char_from_sql(
98 raw
99 )?);
100 } else if arg == parse_quote!(i16) {
101 accepts = quote!(ts_sql_helper_lib::postgres::types::accepts!(INT2));
102 from_sql = quote!(ts_sql_helper_lib::postgres_protocol::types::int2_from_sql(
103 raw
104 )?);
105 } else if arg == parse_quote!(i32) {
106 accepts = quote!(ts_sql_helper_lib::postgres::types::accepts!(INT4));
107 from_sql = quote!(ts_sql_helper_lib::postgres_protocol::types::int4_from_sql(
108 raw
109 )?);
110 } else if arg == parse_quote!(i64) {
111 accepts = quote!(ts_sql_helper_lib::postgres::types::accepts!(INT8));
112 from_sql = quote!(ts_sql_helper_lib::postgres_protocol::types::int8_from_sql(
113 raw
114 )?);
115 } else {
116 continue;
117 }
118
119 repr_type = arg;
120 break;
121 }
122
123 (repr_type, accepts, from_sql)
124 };
125
126 let generics = add_trait_bounds(input.generics, parse_quote!(TryFrom<#repr>));
127 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
128
129 let expanded = quote! {
130 impl<'a> #impl_generics ts_sql_helper_lib::postgres::types::FromSql<'a> for #name #ty_generics #where_clause {
131 fn from_sql(_: &ts_sql_helper_lib::postgres::types::Type, raw: &[u8]) -> Result<Self, Box<dyn core::error::Error + Sync + Send>> {
132 let raw_value = #from_sql;
133 let value = Self::try_from(raw_value)?;
134 Ok(value)
135 }
136
137 #accepts;
138
139 }
140 };
141
142 proc_macro::TokenStream::from(expanded)
143}
144
145fn add_trait_bounds(mut generics: Generics, bounds: TypeParamBound) -> Generics {
147 for param in &mut generics.params {
148 if let GenericParam::Type(ref mut type_param) = *param {
149 type_param.bounds.push(bounds.clone());
150 }
151 }
152 generics
153}