sea_orm_verify/
lib.rs

1use proc_macro2::Ident;
2use quote::quote;
3use syn::{punctuated::Punctuated, token::Comma, DeriveInput, Fields, GenericArgument, Meta, PathArguments, Type};
4
5/// Derive to verify sea-orm Entity with sqlx at compile time against your db
6/// needs to have sqlx in your project dependencies as set DATABASE_URL or generate offline sqlx json data
7#[proc_macro_derive(Verify, attributes(verify))]
8pub fn verify(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
9    let DeriveInput { ident: ty, generics, data, attrs, .. } = syn::parse(input).unwrap();
10
11    let fields = filter_fields(match data {
12        syn::Data::Struct(ref s) => &s.fields,
13        _ => panic!("Field can only be derived for structs"),
14    });
15
16    let mut table_name = None;
17    let mut schema_name = None;
18    let mut table_iden = false;
19    //let DeriveInput { ident, data, attrs, .. } = parse_macro_input!(input as DeriveInput);
20
21    attrs.iter().for_each(|attr| {
22        if attr.path().get_ident().map(|i| i == "sea_orm") != Some(true) {
23            return;
24        }
25
26        if let Ok(list) = attr.parse_args_with(Punctuated::<Meta, Comma>::parse_terminated) {
27            for meta in list.iter() {
28                if let Meta::NameValue(nv) = meta {
29                    if let Some(ident) = nv.path.get_ident() {
30                        if ident == "table_name" {
31                            table_name = Some(nv.value.clone());
32                        } else if ident == "schema_name" {
33                            schema_name = Some(nv.value.clone());
34                        }
35                    }
36                } else if let Meta::Path(path) = meta {
37                    if let Some(ident) = path.get_ident() {
38                        if ident == "table_iden" {
39                            table_iden = true;
40                        }
41                    }
42                }
43            }
44        }
45    });
46
47    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
48    let mut fields_sql = vec![];
49
50    for field in &fields {
51        let name_str = field.ident.to_string();
52        let name = name_str.trim_start_matches("r#");
53        let r#type = &field.r#type;
54        let mut null_override = "";
55        if field.not_null {
56            null_override = "!";
57        } else if field.null {
58            null_override = "?";
59        }
60
61        if field.type_override {
62            if let Type::Path(type_path) = r#type {
63                if let Some(path) = type_path.path.get_ident() {
64                    fields_sql.push(format!(r#"{} as "{}{}: {}""#, name, name, null_override, path));
65                } else {
66                    let outer_type = type_path.path.segments[0].ident.to_string();
67                    match outer_type.as_str() {
68                        "Option" | "Vec" => fields_sql.push(format!(
69                            r#"{} as "{}{}: {}""#,
70                            name,
71                            name,
72                            null_override,
73                            if let PathArguments::AngleBracketed(type_path) = &type_path.path.segments[0].arguments {
74                                if let GenericArgument::Type(Type::Path(type_path)) = type_path.args.first().unwrap() {
75                                    if outer_type == "Vec" {
76                                        format!("Vec<{}>", type_path.path.get_ident().unwrap())
77                                    } else {
78                                        type_path.path.get_ident().unwrap().to_string()
79                                    }
80                                } else {
81                                    panic!("unsupported type patch: {:?}", type_path);
82                                }
83                            } else {
84                                panic!("unsupported type patch: {:?}", type_path);
85                            }
86                        )),
87                        _ => panic!("unsupported type patch: {:?}", type_path),
88                    }
89                }
90            } else {
91                panic!("unsupported field type: {:?}", r#type);
92            }
93        } else if !null_override.is_empty() {
94            fields_sql.push(format!(r#""{}" as "{}{}""#, name, name, null_override));
95        } else {
96            fields_sql.push(format!(r#""{}""#, name));
97        }
98    }
99
100    let fields_sql = fields_sql.join(", ");
101
102    let sql = if let Some(schema_name) = schema_name {
103        format!("SELECT {} FROM {}.{}", fields_sql, quote! { #schema_name }, quote! { #table_name })
104    } else {
105        format!("SELECT {} FROM {}", fields_sql, quote! { #table_name })
106    };
107
108    let tokens = quote! {
109        impl #impl_generics #ty #ty_generics
110            #where_clause
111        {
112            #[allow(unused_must_use)]
113            async fn _verify() {
114                sqlx::query_as!(Self, #sql);
115            }
116        }
117    };
118    tokens.into()
119}
120/// Parse field attributes
121fn filter_fields(fields: &Fields) -> Vec<Field> {
122    fields
123        .iter()
124        .filter_map(|field| {
125            if field.ident.is_some() {
126                let mut type_override = false;
127                let mut not_null = false;
128                let mut null = false;
129                for attr in &field.attrs {
130                    if attr.path().get_ident().map(|i| i == "verify") == Some(true) {
131                        if let Ok(list) = attr.parse_args_with(Punctuated::<Meta, Comma>::parse_terminated) {
132                            for meta in list.iter() {
133                                if let Meta::Path(path) = meta {
134                                    if let Some(ident) = path.get_ident() {
135                                        if ident == "type_override" {
136                                            type_override = true;
137                                        } else if ident == "not_null" {
138                                            not_null = true;
139                                        } else if ident == "null" {
140                                            null = true;
141                                        }
142                                    }
143                                }
144                            }
145                        }
146                    }
147                }
148                let field_ident = field.ident.as_ref().unwrap().clone();
149                let field_ty = field.ty.clone();
150                if not_null && null {
151                    panic!("not_null and null can not be set at the same time");
152                }
153
154                Some(Field::new(field_ident, field_ty, type_override, not_null, null))
155            } else {
156                None
157            }
158        })
159        .collect::<Vec<_>>()
160}
161
162#[derive(Debug)]
163struct Field {
164    pub ident: Ident,
165    pub r#type: Type,
166    pub type_override: bool,
167    pub not_null: bool,
168    pub null: bool,
169}
170
171impl Field {
172    pub fn new(ident: Ident, r#type: Type, type_override: bool, not_null: bool, null: bool) -> Self {
173        Self { ident, r#type, type_override, not_null, null }
174    }
175}