1use proc_macro2::Ident;
2use quote::quote;
3use syn::{punctuated::Punctuated, token::Comma, DeriveInput, Fields, GenericArgument, Meta, PathArguments, Type};
4
5#[proc_macro_derive(Verify, attributes(verify))]
8pub fn verify(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
9 let DeriveInput { ident: ty, generics, data, attrs, .. } = syn::parse(input).unwrap();
10
11 let fields = filter_fields(match data {
12 syn::Data::Struct(ref s) => &s.fields,
13 _ => panic!("Field can only be derived for structs"),
14 });
15
16 let mut table_name = None;
17 let mut schema_name = None;
18 let mut table_iden = false;
19 attrs.iter().for_each(|attr| {
22 if attr.path().get_ident().map(|i| i == "sea_orm") != Some(true) {
23 return;
24 }
25
26 if let Ok(list) = attr.parse_args_with(Punctuated::<Meta, Comma>::parse_terminated) {
27 for meta in list.iter() {
28 if let Meta::NameValue(nv) = meta {
29 if let Some(ident) = nv.path.get_ident() {
30 if ident == "table_name" {
31 table_name = Some(nv.value.clone());
32 } else if ident == "schema_name" {
33 schema_name = Some(nv.value.clone());
34 }
35 }
36 } else if let Meta::Path(path) = meta {
37 if let Some(ident) = path.get_ident() {
38 if ident == "table_iden" {
39 table_iden = true;
40 }
41 }
42 }
43 }
44 }
45 });
46
47 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
48 let mut fields_sql = vec![];
49
50 for field in &fields {
51 let name_str = field.ident.to_string();
52 let name = name_str.trim_start_matches("r#");
53 let r#type = &field.r#type;
54 let mut null_override = "";
55 if field.not_null {
56 null_override = "!";
57 } else if field.null {
58 null_override = "?";
59 }
60
61 if field.type_override {
62 if let Type::Path(type_path) = r#type {
63 if let Some(path) = type_path.path.get_ident() {
64 fields_sql.push(format!(r#"{} as "{}{}: {}""#, name, name, null_override, path));
65 } else {
66 let outer_type = type_path.path.segments[0].ident.to_string();
67 match outer_type.as_str() {
68 "Option" | "Vec" => fields_sql.push(format!(
69 r#"{} as "{}{}: {}""#,
70 name,
71 name,
72 null_override,
73 if let PathArguments::AngleBracketed(type_path) = &type_path.path.segments[0].arguments {
74 if let GenericArgument::Type(Type::Path(type_path)) = type_path.args.first().unwrap() {
75 if outer_type == "Vec" {
76 format!("Vec<{}>", type_path.path.get_ident().unwrap())
77 } else {
78 type_path.path.get_ident().unwrap().to_string()
79 }
80 } else {
81 panic!("unsupported type patch: {:?}", type_path);
82 }
83 } else {
84 panic!("unsupported type patch: {:?}", type_path);
85 }
86 )),
87 _ => panic!("unsupported type patch: {:?}", type_path),
88 }
89 }
90 } else {
91 panic!("unsupported field type: {:?}", r#type);
92 }
93 } else if !null_override.is_empty() {
94 fields_sql.push(format!(r#""{}" as "{}{}""#, name, name, null_override));
95 } else {
96 fields_sql.push(format!(r#""{}""#, name));
97 }
98 }
99
100 let fields_sql = fields_sql.join(", ");
101
102 let sql = if let Some(schema_name) = schema_name {
103 format!("SELECT {} FROM {}.{}", fields_sql, quote! { #schema_name }, quote! { #table_name })
104 } else {
105 format!("SELECT {} FROM {}", fields_sql, quote! { #table_name })
106 };
107
108 let tokens = quote! {
109 impl #impl_generics #ty #ty_generics
110 #where_clause
111 {
112 #[allow(unused_must_use)]
113 async fn _verify() {
114 sqlx::query_as!(Self, #sql);
115 }
116 }
117 };
118 tokens.into()
119}
120fn filter_fields(fields: &Fields) -> Vec<Field> {
122 fields
123 .iter()
124 .filter_map(|field| {
125 if field.ident.is_some() {
126 let mut type_override = false;
127 let mut not_null = false;
128 let mut null = false;
129 for attr in &field.attrs {
130 if attr.path().get_ident().map(|i| i == "verify") == Some(true) {
131 if let Ok(list) = attr.parse_args_with(Punctuated::<Meta, Comma>::parse_terminated) {
132 for meta in list.iter() {
133 if let Meta::Path(path) = meta {
134 if let Some(ident) = path.get_ident() {
135 if ident == "type_override" {
136 type_override = true;
137 } else if ident == "not_null" {
138 not_null = true;
139 } else if ident == "null" {
140 null = true;
141 }
142 }
143 }
144 }
145 }
146 }
147 }
148 let field_ident = field.ident.as_ref().unwrap().clone();
149 let field_ty = field.ty.clone();
150 if not_null && null {
151 panic!("not_null and null can not be set at the same time");
152 }
153
154 Some(Field::new(field_ident, field_ty, type_override, not_null, null))
155 } else {
156 None
157 }
158 })
159 .collect::<Vec<_>>()
160}
161
/// A named struct field plus the options read from its `#[verify(...)]`
/// attributes, as produced by `filter_fields` and consumed when the SELECT
/// column list is built.
#[derive(Debug)]
struct Field {
    /// Identifier of the field; a leading `r#` is trimmed off before it is used
    /// as the column name.
    pub ident: Ident,
    /// Declared type of the field; only inspected when `type_override` is set.
    pub r#type: Type,
    /// Set by `#[verify(type_override)]`: the column is selected with a
    /// `"name: Type"` alias derived from `r#type`.
    pub type_override: bool,
    /// Set by `#[verify(not_null)]`: appends `!` to the column alias.
    pub not_null: bool,
    /// Set by `#[verify(null)]`: appends `?` to the column alias. Never set
    /// together with `not_null` (`filter_fields` panics in that case).
    pub null: bool,
}
170
171impl Field {
172 pub fn new(ident: Ident, r#type: Type, type_override: bool, not_null: bool, null: bool) -> Self {
173 Self { ident, r#type, type_override, not_null, null }
174 }
175}