ts_sql_helper_derive/
lib.rs1use quote::{quote, quote_spanned};
5use syn::{
6 Data, DeriveInput, Fields, GenericParam, Generics, Type, TypeParamBound, parse_macro_input,
7 parse_quote, spanned::Spanned,
8};
9
10#[proc_macro_derive(FromRow)]
12pub fn derive_from_row(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
13 let input = parse_macro_input!(input as DeriveInput);
15
16 let name = input.ident;
17
18 let generics = add_trait_bounds(
20 input.generics,
21 parse_quote!(ts_sql_helper_lib::postgres::types::FromSql),
22 );
23 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
24
25 let Data::Struct(data_struct) = input.data else {
26 panic!("FromRow can only be derived on a struct")
27 };
28
29 let Fields::Named(fields) = data_struct.fields else {
30 panic!("FromRow can only be derived on a struct with named fields")
31 };
32
33 let each_field_from_row = fields.named.iter().filter_map(|f| {
34 let name = f.ident.as_ref()?;
35 let name_lit = name.to_string();
36 let field_type = &f.ty;
37
38 Some(quote_spanned! {f.span()=>
39 let #name: #field_type = row.try_get(#name_lit)?;
40 })
41 });
42
43 let struct_fields = fields.named.iter().map(|f| {
44 let name = &f.ident;
45 quote_spanned! {f.span() => #name}
46 });
47
48 let expanded = quote! {
49 impl #impl_generics ts_sql_helper_lib::FromRow for #name #ty_generics #where_clause {
51 fn from_row(row: &ts_sql_helper_lib::postgres::Row) -> Result<Self, ts_sql_helper_lib::postgres::Error> {
52 #( #each_field_from_row )*
53
54 Ok(Self {
55 #( #struct_fields ),*
56 })
57 }
58 }
59 };
60
61 proc_macro::TokenStream::from(expanded)
63}
64
65#[proc_macro_derive(FromSql)]
67pub fn derive_from_sql(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
68 let input = parse_macro_input!(input as DeriveInput);
70
71 if !matches!(input.data, Data::Enum(_)) {
72 panic!("FromSql can only be derived on an enum")
73 }
74
75 let name = input.ident;
76
77 let (repr, accepts, from_sql) = {
78 let mut repr_type = parse_quote!(&str);
79 let mut accepts: Vec<Type> = vec![
80 parse_quote!(ts_sql_helper_lib::postgres_types::Type::TEXT),
81 parse_quote!(ts_sql_helper_lib::postgres_types::Type::VARCHAR),
82 ];
83 let mut from_sql = quote!(ts_sql_helper_lib::postgres_protocol::types::text_from_sql(
84 raw
85 )?);
86
87 for attr in input.attrs {
88 if !attr.path().is_ident("repr") {
89 continue;
90 }
91
92 let Ok(arg) = attr.parse_args::<Type>() else {
93 continue;
94 };
95
96 if arg == parse_quote!(i8) {
97 accepts = vec![parse_quote!(ts_sql_helper_lib::postgres_types::Type::CHAR)];
98 from_sql = quote!(ts_sql_helper_lib::postgres_protocol::types::char_from_sql(
99 raw
100 )?);
101 } else if arg == parse_quote!(i16) {
102 accepts = vec![parse_quote!(ts_sql_helper_lib::postgres_types::Type::INT2)];
103 from_sql = quote!(ts_sql_helper_lib::postgres_protocol::types::int2_from_sql(
104 raw
105 )?);
106 } else if arg == parse_quote!(i32) {
107 accepts = vec![parse_quote!(ts_sql_helper_lib::postgres_types::Type::INT4)];
108 from_sql = quote!(ts_sql_helper_lib::postgres_protocol::types::int4_from_sql(
109 raw
110 )?);
111 } else if arg == parse_quote!(i64) {
112 accepts = vec![parse_quote!(ts_sql_helper_lib::postgres_types::Type::INT8)];
113 from_sql = quote!(ts_sql_helper_lib::postgres_protocol::types::int8_from_sql(
114 raw
115 )?);
116 } else {
117 continue;
118 }
119
120 repr_type = arg;
121 break;
122 }
123
124 (repr_type, accepts, from_sql)
125 };
126
127 let generics = add_trait_bounds(input.generics, parse_quote!(TryFrom<#repr>));
128 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
129
130 let expanded = quote! {
131 impl<'a> #impl_generics ts_sql_helper_lib::postgres::types::FromSql<'a> for #name #ty_generics #where_clause {
132 fn from_sql(_: &ts_sql_helper_lib::postgres::types::Type, raw: &[u8]) -> Result<Self, Box<dyn core::error::Error + Sync + Send>> {
133 let raw_value = #from_sql;
134 let value = Self::try_from(raw_value)?;
135 Ok(value)
136 }
137
138 fn accepts(ty: &ts_sql_helper_lib::postgres_types::Type) -> bool {
139 match (*ty) {
140 #(#accepts)|* => true,
141 _ => false,
142 }
143 }
144 }
145 };
146
147 proc_macro::TokenStream::from(expanded)
148}
149
150fn add_trait_bounds(mut generics: Generics, bounds: TypeParamBound) -> Generics {
152 for param in &mut generics.params {
153 if let GenericParam::Type(ref mut type_param) = *param {
154 type_param.bounds.push(bounds.clone());
155 }
156 }
157 generics
158}