Skip to main content

tokio_pg_mapper_derive/
lib.rs

1extern crate proc_macro;
2#[macro_use]
3extern crate quote;
4extern crate syn;
5
6use proc_macro::TokenStream;
7
8use syn::{
9    Data, DataStruct, DeriveInput, Ident, ImplGenerics, Item,
10    Meta::{List, NameValue},
11    NestedMeta::Meta,
12    TypeGenerics, WhereClause,
13};
14
15#[proc_macro_derive(PostgresMapper, attributes(pg_mapper))]
16pub fn postgres_mapper(input: TokenStream) -> TokenStream {
17    let mut ast: DeriveInput = syn::parse(input).expect("Couldn't parse item");
18
19    impl_derive(&mut ast)
20}
21
22fn impl_derive(ast: &mut DeriveInput) -> TokenStream {
23    let name = &ast.ident;
24    let table_name = parse_table_attr(&ast);
25
26    let (impl_generics, ty_generics, where_clause) = &ast.generics.split_for_impl();
27
28    let s = match ast.data {
29        Data::Struct(ref s) => s,
30        _ => panic!("Enums or Unions can not be mapped"),
31    };
32
33    let tokio_pg_mapper = impl_tokio_pg_mapper(
34        s,
35        name,
36        &table_name,
37        impl_generics,
38        ty_generics,
39        where_clause,
40    );
41
42    let tokens = quote! {
43        #tokio_pg_mapper
44    };
45
46    tokens.into()
47}
48
49fn impl_tokio_pg_mapper(
50    s: &DataStruct,
51    name: &Ident,
52    table_name: &str,
53    impl_generics: &ImplGenerics,
54    ty_generics: &TypeGenerics,
55    where_clause: &Option<&WhereClause>,
56) -> Item {
57    let fields = s.fields.iter().map(|field| {
58        let ident = field.ident.as_ref().unwrap();
59        let ty = &field.ty;
60
61        let row_expr = format!(r##"{}"##, ident);
62        quote! {
63            #ident:row.try_get::<&str,#ty>(#row_expr)?
64        }
65    });
66
67    let ref_fields = s.fields.iter().map(|field| {
68        let ident = field.ident.as_ref().unwrap();
69        let ty = &field.ty;
70
71        let row_expr = format!(r##"{}"##, ident);
72        quote! {
73            #ident:row.try_get::<&str,#ty>(&#row_expr)?
74        }
75    });
76
77    let table_columns = s
78        .fields
79        .iter()
80        .map(|field| {
81            let ident = field
82                .ident
83                .as_ref()
84                .expect("Expected structfield identifier");
85            format!(" {0}.{1} ", table_name, ident)
86        })
87        .collect::<Vec<String>>()
88        .join(", ");
89
90    let columns = s
91        .fields
92        .iter()
93        .map(|field| {
94            let ident = field
95                .ident
96                .as_ref()
97                .expect("Expected structfield identifier");
98            format!(" {} ", ident)
99        })
100        .collect::<Vec<String>>()
101        .join(", ");
102
103    let tokens = quote! {
104        impl #impl_generics tokio_pg_mapper::FromTokioPostgresRow for #name #ty_generics #where_clause {
105            fn from_row(row: tokio_postgres::row::Row) -> ::std::result::Result<Self, tokio_pg_mapper::Error> {
106                Ok(Self {
107                    #(#fields),*
108                })
109            }
110
111            fn from_row_ref(row: &tokio_postgres::row::Row) -> ::std::result::Result<Self, tokio_pg_mapper::Error> {
112                Ok(Self {
113                    #(#ref_fields),*
114                })
115            }
116
117            fn sql_table() -> String {
118                #table_name.to_string()
119            }
120
121            fn sql_table_fields() -> String {
122                #table_columns.to_string()
123            }
124
125            fn sql_fields() -> String {
126                #columns.to_string()
127            }
128        }
129    };
130
131    syn::parse_quote!(#tokens)
132}
133
134fn get_mapper_meta_items(attr: &syn::Attribute) -> Option<Vec<syn::NestedMeta>> {
135    if attr.path.segments.len() == 1 && attr.path.segments[0].ident == "pg_mapper" {
136        match attr.parse_meta() {
137            Ok(List(ref meta)) => Some(meta.nested.iter().cloned().collect()),
138            _ => {
139                panic!("declare table name: #[pg_mapper(table = \"foo\")]");
140            }
141        }
142    } else {
143        None
144    }
145}
146
147fn get_lit_str<'a>(
148    attr_name: Option<&Ident>,
149    lit: &'a syn::Lit,
150) -> ::std::result::Result<&'a syn::LitStr, ()> {
151    if let syn::Lit::Str(ref lit) = *lit {
152        Ok(lit)
153    } else {
154        if let Some(val) = attr_name {
155            panic!("expected pg_mapper {:?} attribute to be a string", val);
156        } else {
157            panic!("expected pg_mapper attribute to be a string");
158        }
159        #[allow(unreachable_code)]
160        Err(())
161    }
162}
163
164fn parse_table_attr(ast: &DeriveInput) -> String {
165    // Parse `#[pg_mapper(table = "foo")]`
166    let mut table_name: Option<String> = None;
167
168    for meta_items in ast.attrs.iter().filter_map(get_mapper_meta_items) {
169        for meta_item in meta_items {
170            match meta_item {
171                // Parse `#[pg_mapper(table = "foo")]`
172                Meta(NameValue(ref m)) if m.path.is_ident("table") => {
173                    if let Ok(s) = get_lit_str(m.path.get_ident(), &m.lit) {
174                        table_name = Some(s.value());
175                    }
176                }
177                Meta(_) => {
178                    panic!(format!("unknown pg_mapper container attribute",))
179                }
180                _ => {
181                    panic!("unexpected literal in pg_mapper container attribute");
182                }
183            }
184        }
185    }
186
187    table_name.expect("declare table name: #[pg_mapper(table = \"foo\")]")
188}