1use heck::{
2 ToKebabCase, ToLowerCamelCase, ToPascalCase, ToShoutySnakeCase, ToSnakeCase, ToUpperCamelCase,
3};
4use proc_macro::TokenStream;
5use quote::quote;
6use syn::{
7 parse_macro_input, Attribute, Data, DataEnum, DataStruct, DeriveInput, Fields, Meta, NestedMeta,
8};
9
10fn get_sqlx_field_rename(attrs: &[Attribute]) -> Option<String> {
11 for attr in attrs.iter() {
12 let meta = attr
13 .parse_meta()
14 .map_err(|e| syn::Error::new_spanned(attr, e))
15 .unwrap();
16 if let Meta::List(list) = meta {
17 for cattr in list.nested.iter() {
18 if let NestedMeta::Meta(Meta::NameValue(ref attr_ident)) = cattr {
19 let name = attr_ident.clone();
20 let name = name.path.get_ident().unwrap().to_string();
21 let name = name.as_str();
22 let ident = attr_ident.clone();
23 if name == "rename" {
24 let rename = match ident.lit {
25 syn::Lit::Str(val) => val,
26 _ => unreachable!("rename be string"),
27 }
28 .value();
29 return Some(rename);
30 }
31 }
32 }
33 }
34 }
35 None
36}
37fn change_sqlx_field_rename(change_type: &Option<String>, field_name: String) -> String {
38 if let Some(str) = change_type {
39 match str.as_str() {
40 "lowercase" => {
41 return field_name.to_lowercase();
42 }
43 "snake_case" => {
44 return field_name.to_snake_case();
45 }
46 "UPPERCASE" => {
47 return field_name.to_uppercase();
48 }
49 "SCREAMING_SNAKE_CASE" => {
50 return field_name.to_shouty_snake_case();
51 }
52 "kebab-case" => {
53 return field_name.to_kebab_case();
54 }
55 "camelCase" => {
56 return field_name.to_lower_camel_case();
57 }
58 "UpperCamelCase" => {
59 return field_name.to_upper_camel_case();
60 }
61 "PascalCase" => {
62 return field_name.to_pascal_case();
63 }
64 _ => {}
65 }
66 }
67 field_name
68}
69
70#[proc_macro_attribute]
71pub fn sqlx_model(args: TokenStream, item: TokenStream) -> TokenStream {
73 let input = parse_macro_input!(item as DeriveInput);
74 let struct_name = &input.ident;
75
76 let mut db_type = None;
77 let mut table_name = None;
78 let mut rename_all = None;
79 let mut table_pk = vec![];
80 let args = syn::parse_macro_input!(args as syn::AttributeArgs);
81 for cattr in args.iter() {
91 if let NestedMeta::Meta(Meta::NameValue(ref attr_ident)) = cattr {
92 let name = attr_ident.clone();
93 let name = name.path.get_ident().unwrap().to_string();
94 let name = name.as_str();
95 let ident = attr_ident.clone();
96 match name {
97 "db_type" => {
98 let val = match ident.lit {
99 syn::Lit::Str(val) => val,
100 _ => unreachable!("table name must be string"),
101 }
102 .value();
103 db_type = Some(val);
104 }
105 "table_name" => {
106 let val = match ident.lit {
107 syn::Lit::Str(val) => val,
108 _ => unreachable!("table name must be string"),
109 }
110 .value();
111 table_name = Some(val);
112 }
113 "table_pk" => {
114 let val = match ident.lit {
115 syn::Lit::Str(val) => val,
116 _ => unreachable!("table pk field must be string"),
117 }
118 .value();
119 table_pk.push(val);
120 }
121 "rename_all" => {
122 if let syn::Lit::Str(val) = ident.lit {
123 let str = &*val.value();
124 rename_all = Some(str.to_owned());
125 }
126 }
127 _ => {}
128 }
129 }
130 }
131 let db_type = quote::format_ident!("{}", db_type.expect("database type not set"));
134 let table_name = table_name.unwrap_or_else(|| {
135 let mut name = struct_name.to_string();
136 if name.clone().drain(0..5).collect::<String>() == "Model" {
137 name = name.drain(5..).collect::<String>();
138 }
139 if name.clone().drain(name.len() - 5..).collect::<String>() == "Model" {
140 name = name.drain(0..name.len() - 5).collect::<String>();
141 }
142 name.chars()
143 .enumerate()
144 .map(|(i, e)| {
145 if i != 0 && e as u8 >= 65 && e as u8 <= 90 {
146 format!("_{}", e.to_ascii_lowercase())
147 } else {
148 e.to_ascii_lowercase().to_string()
149 }
150 })
151 .collect::<Vec<String>>()
152 .join("")
153 });
154 let expanded = match &input.data {
155 Data::Struct(DataStruct { ref fields, .. }) => {
156 if let Fields::Named(ref fields_name) = fields {
157 let change_fields: Vec<_> = fields_name
158 .named
159 .iter()
160 .map(|field| {
161 let field_name = field.ident.as_ref().unwrap();
162 let str_field_name = match get_sqlx_field_rename(&field.attrs) {
163 Some(str) => str,
164 _ => change_sqlx_field_rename(&rename_all, field_name.to_string()),
165 };
166 let field_type = field.ty.clone();
167 quote! {
168 #field_name[#str_field_name]:#field_type
169 }
170 })
171 .collect();
172 let bind_fields: Vec<_> = fields_name
173 .named
174 .iter()
175 .map(|field| {
176 let field_name = field.ident.as_ref().unwrap();
177 let str_field_name = match get_sqlx_field_rename(&field.attrs) {
178 Some(str) => str,
179 _ => change_sqlx_field_rename(&rename_all, field_name.to_string()),
180 };
181 quote! {
182 #field_name[#str_field_name]
183 }
184 })
185 .collect();
186 let change_struct = quote::format_ident!("{}Ref", struct_name);
187 let mut pk_fields = vec![];
188 for field in fields_name.named.iter() {
189 let field_name = field.ident.as_ref().unwrap();
190 if table_pk.contains(&field_name.to_string()) {
191 let str_field_name = match get_sqlx_field_rename(&field.attrs) {
192 Some(str) => str,
193 _ => change_sqlx_field_rename(&rename_all, field_name.to_string()),
194 };
195 pk_fields.push(quote! {
196 #field_name[#str_field_name]
197 });
198 }
199 }
200 if pk_fields.is_empty() {
201 if let Some(field) = fields_name.named.iter().next() {
202 let field_name = field.ident.as_ref().unwrap();
203 let str_field_name = match get_sqlx_field_rename(&field.attrs) {
204 Some(str) => str,
205 _ => change_sqlx_field_rename(&rename_all, field_name.to_string()),
206 };
207 pk_fields.push(quote! {
208 #field_name[#str_field_name]
209 });
210 }
211 }
212 let implemented_show = quote! {
213 #input
214 sqlx_model::model_table_value_bind_define!(sqlx::#db_type,#struct_name,#table_name,{#(#bind_fields),*},{#(#pk_fields),*});
215 sqlx_model::model_table_ref_define!(sqlx::#db_type,#struct_name,#change_struct,{#(#change_fields),*});
216 };
217 implemented_show
218 } else {
219 panic!("sorry, may it's a complicated struct.");
220 }
221 }
222 _ => panic!("sorry, Show is not implemented for union or enum type."),
223 };
224 expanded.into()
225}
226
227#[proc_macro_attribute]
228pub fn sqlx_model_status(args: TokenStream, item: TokenStream) -> TokenStream {
230 let input = parse_macro_input!(item as DeriveInput);
231 let struct_name = &input.ident;
232 let args = syn::parse_macro_input!(args as syn::AttributeArgs);
233 let mut field_type = None;
234 for cattr in args.iter() {
245 if let NestedMeta::Meta(Meta::NameValue(ref attr_ident)) = cattr {
246 let name = attr_ident.clone();
247 let name = name.path.get_ident().unwrap().to_string();
248 let name = name.as_str();
249 let ident = attr_ident.clone();
250 if name == "field_type" {
251 field_type = Some(
252 match ident.lit {
253 syn::Lit::Str(val) => val,
254 _ => unreachable!("status type must be string"),
255 }
256 .value(),
257 );
258 }
259 }
260 }
261 let field_type = field_type.expect("status type not set");
262 let field_type = quote::format_ident!("{}", field_type);
265 let expanded = match input.data {
266 Data::Enum(DataEnum { ref variants, .. }) => {
267 let fields: Vec<_> = variants
268 .iter()
269 .map(|field| {
270 let field_name = field.ident.clone();
271 quote! {
272 #struct_name::#field_name
273 }
274 })
275 .collect();
276 quote! {
277 #input
278 sqlx_model::model_enum_status_define!(#struct_name,#field_type,{#(#fields),*});
279 }
280 }
281 _ => panic!("sorry, Show is not implemented for union or enum type."),
282 };
283 expanded.into()
284}