tokio_pg_mapper_derive/
lib.rs1extern crate proc_macro;
2#[macro_use]
3extern crate quote;
4extern crate syn;
5
6use proc_macro::TokenStream;
7
8use syn::{
9 Data, DataStruct, DeriveInput, Ident, ImplGenerics, Item,
10 Meta::{List, NameValue},
11 NestedMeta::Meta,
12 TypeGenerics, WhereClause,
13};
14
15#[proc_macro_derive(PostgresMapper, attributes(pg_mapper))]
16pub fn postgres_mapper(input: TokenStream) -> TokenStream {
17 let mut ast: DeriveInput = syn::parse(input).expect("Couldn't parse item");
18
19 impl_derive(&mut ast)
20}
21
22fn impl_derive(ast: &mut DeriveInput) -> TokenStream {
23 let name = &ast.ident;
24 let table_name = parse_table_attr(&ast);
25
26 let (impl_generics, ty_generics, where_clause) = &ast.generics.split_for_impl();
27
28 let s = match ast.data {
29 Data::Struct(ref s) => s,
30 _ => panic!("Enums or Unions can not be mapped"),
31 };
32
33 let tokio_pg_mapper = impl_tokio_pg_mapper(
34 s,
35 name,
36 &table_name,
37 impl_generics,
38 ty_generics,
39 where_clause,
40 );
41
42 let tokens = quote! {
43 #tokio_pg_mapper
44 };
45
46 tokens.into()
47}
48
49fn impl_tokio_pg_mapper(
50 s: &DataStruct,
51 name: &Ident,
52 table_name: &str,
53 impl_generics: &ImplGenerics,
54 ty_generics: &TypeGenerics,
55 where_clause: &Option<&WhereClause>,
56) -> Item {
57 let fields = s.fields.iter().map(|field| {
58 let ident = field.ident.as_ref().unwrap();
59 let ty = &field.ty;
60
61 let row_expr = format!(r##"{}"##, ident);
62 quote! {
63 #ident:row.try_get::<&str,#ty>(#row_expr)?
64 }
65 });
66
67 let ref_fields = s.fields.iter().map(|field| {
68 let ident = field.ident.as_ref().unwrap();
69 let ty = &field.ty;
70
71 let row_expr = format!(r##"{}"##, ident);
72 quote! {
73 #ident:row.try_get::<&str,#ty>(&#row_expr)?
74 }
75 });
76
77 let table_columns = s
78 .fields
79 .iter()
80 .map(|field| {
81 let ident = field
82 .ident
83 .as_ref()
84 .expect("Expected structfield identifier");
85 format!(" {0}.{1} ", table_name, ident)
86 })
87 .collect::<Vec<String>>()
88 .join(", ");
89
90 let columns = s
91 .fields
92 .iter()
93 .map(|field| {
94 let ident = field
95 .ident
96 .as_ref()
97 .expect("Expected structfield identifier");
98 format!(" {} ", ident)
99 })
100 .collect::<Vec<String>>()
101 .join(", ");
102
103 let tokens = quote! {
104 impl #impl_generics tokio_pg_mapper::FromTokioPostgresRow for #name #ty_generics #where_clause {
105 fn from_row(row: tokio_postgres::row::Row) -> ::std::result::Result<Self, tokio_pg_mapper::Error> {
106 Ok(Self {
107 #(#fields),*
108 })
109 }
110
111 fn from_row_ref(row: &tokio_postgres::row::Row) -> ::std::result::Result<Self, tokio_pg_mapper::Error> {
112 Ok(Self {
113 #(#ref_fields),*
114 })
115 }
116
117 fn sql_table() -> String {
118 #table_name.to_string()
119 }
120
121 fn sql_table_fields() -> String {
122 #table_columns.to_string()
123 }
124
125 fn sql_fields() -> String {
126 #columns.to_string()
127 }
128 }
129 };
130
131 syn::parse_quote!(#tokens)
132}
133
134fn get_mapper_meta_items(attr: &syn::Attribute) -> Option<Vec<syn::NestedMeta>> {
135 if attr.path.segments.len() == 1 && attr.path.segments[0].ident == "pg_mapper" {
136 match attr.parse_meta() {
137 Ok(List(ref meta)) => Some(meta.nested.iter().cloned().collect()),
138 _ => {
139 panic!("declare table name: #[pg_mapper(table = \"foo\")]");
140 }
141 }
142 } else {
143 None
144 }
145}
146
147fn get_lit_str<'a>(
148 attr_name: Option<&Ident>,
149 lit: &'a syn::Lit,
150) -> ::std::result::Result<&'a syn::LitStr, ()> {
151 if let syn::Lit::Str(ref lit) = *lit {
152 Ok(lit)
153 } else {
154 if let Some(val) = attr_name {
155 panic!("expected pg_mapper {:?} attribute to be a string", val);
156 } else {
157 panic!("expected pg_mapper attribute to be a string");
158 }
159 #[allow(unreachable_code)]
160 Err(())
161 }
162}
163
164fn parse_table_attr(ast: &DeriveInput) -> String {
165 let mut table_name: Option<String> = None;
167
168 for meta_items in ast.attrs.iter().filter_map(get_mapper_meta_items) {
169 for meta_item in meta_items {
170 match meta_item {
171 Meta(NameValue(ref m)) if m.path.is_ident("table") => {
173 if let Ok(s) = get_lit_str(m.path.get_ident(), &m.lit) {
174 table_name = Some(s.value());
175 }
176 }
177 Meta(_) => {
178 panic!(format!("unknown pg_mapper container attribute",))
179 }
180 _ => {
181 panic!("unexpected literal in pg_mapper container attribute");
182 }
183 }
184 }
185 }
186
187 table_name.expect("declare table name: #[pg_mapper(table = \"foo\")]")
188}