table_creator_macro2/
lib.rs

1use quote::quote;
2use syn::{ Data, DeriveInput, Field, Fields};
3
4//For the macro we shared must use the signature of proc_macro::TokenStream, but in the test we 
5//can't use the input as the proc_macro::TokenStre. So, we wrap the core function, which has
6//the proc_macro2::TokkenStream, into another function which has an proc_macro::TokenStream input.
7#[proc_macro_derive(TableCreator, attributes(table_ignore, primary_key))]
8pub fn derive_table_creator( input :proc_macro::TokenStream ) -> proc_macro::TokenStream {
9    let input: DeriveInput = syn::parse(input as proc_macro::TokenStream).expect("No input");
10    //this is the name of the struct.
11    let name = input.ident;
12
13    //This is the name of the table,  with the lowercase.
14    // struct Abcd => table_name: abcd.
15    let name_lower = name.to_string().to_lowercase();
16
17
18    let generics = input.generics;
19
20    let fields = match input.data {
21        Data::Struct(ref data) =>  match &data.fields {
22            Fields::Named(fileds) => &fileds.named,
23            _ => panic!("Only named fields are supported"),
24        },
25
26        _ => panic!("Only Struct are supported"),
27    };
28
29
30    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
31    
32    //The primary should be the unique one, so we declare primary_info like this is reasonalble.
33    //If the primary_key is not exsited, we just use a default one, named it with "id".
34    let mut primary_info : Option<(String, syn::Type)> = None;
35    
36    //We will filter the fields with the arttubute #[table_ignore], and find the primary 
37    let fields: Vec<&Field> = fields.iter()
38        .filter(|f| {
39            //This two is the flags to mark this field.
40            let mut is_primary = false;
41            let mut is_ignore = false;
42            //This is the counter for the attribute #[primary_key], if the counter != 1,
43            //just panic!
44            let mut counter = 0; 
45   
46
47            //check there is or not an attribute: #[table_ignor]
48            let _ = f.attrs.iter().for_each(|a|{
49                //check the attributes
50                if let Some(last) = a.path.segments.last() {
51                    if last.ident == "table_ignore" {
52                        is_ignore = true;
53                    }else if last.ident == "primary_key" {
54                        is_primary = true;
55                        counter += 1;
56                        eprintln!("find the attr primary_key in {}", &f.ident.as_ref().unwrap().to_string());
57                    }
58                }
59              
60            });
61            //if this field have attribute #[primary_key], set the primary_info once.
62            if is_primary {
63                let ident = f.ident.as_ref().unwrap().to_string();
64                let ty = f.ty.clone();
65                //check counter of the primary key.
66                if counter != 1 {
67                    panic!("Attribute error, can't have more #[primary_key] than once");
68                }
69                primary_info = Some((ident, ty));
70            }
71            //if have the attribute #[table_ignore], is_ignore will set with true, 
72            //that means we must skip this field, so we return the NOT value of is_ignore 
73            //to the filter, beacuse of the itor with true value would be skiped.
74            !is_ignore
75
76                
77        })
78        //check there is or not an attribute: #[primary_key].
79        //If it's tru, we will set the element as the primary key, or use a default "id" as the primary key instead.      
80    .collect();
81    
82    //store the ident of the column into the vector.  quote will pase it as the column name.
83    let column_ident: Vec<_>  = fields.iter()
84        .map(|f| f.ident.as_ref().unwrap())
85        .collect();
86    //store the transformed type into the vector. 
87    let fields_create: Vec<String> = fields.iter().map( |f|{
88        let field_name = f.ident.as_ref().unwrap().to_string();
89        
90        
91        let ty = match &f.ty {
92            syn::Type::Path(p) =>{
93                let last_indet = p.path.segments.last().unwrap().ident.clone();
94                match last_indet.to_string().as_str() {
95                    //sqlite's sql types   
96                    "i32" | "u32" | "i64" | "u64" => "INTEGER",
97                    "f32" | "f64" => "REAL",
98                    "String" | "str" => "TEXT",
99                    "bool" => "BOOLEAN",
100                    "NaiveDate" => "DATE",          // chrono::NaiveDate
101                    "NaiveDateTime" => "TIMESTAMP", // chrono::NaiveDateTime
102                    "UuidBytes" => "BYTEA",         //[u8;16]
103                    "Hash" => "BYTEA",              //[u8;32]
104                    
105                    _=> "TEXT" ,
106                }
107            }, 
108            _=> "TEXT",
109        };
110        //if the primary_info is not none.
111        if let Some(primary_info) = primary_info.clone() {
112            //if we have a primary key, add the string "PRIMARY KEY" after the field name.
113            //at here, we can unwrap it safely.
114            if field_name == primary_info.clone().0 {
115                //need some space before the "P"
116                return  format!("{} {} PRIMARY KEY", field_name, ty);
117            }
118        }
119        return format!("{} {}", field_name, ty);
120        
121    })
122    .collect(); 
123
124    //sotre the field name into the column_names.
125    //why we collect the filed name at hear? Because we need to know if there is a primary key.
126    let mut column_names:Vec<String> = Vec::new();
127    
128    fields.iter().for_each(|f|{
129        let field_name = f.ident.as_ref().unwrap().to_string();
130        column_names.push(field_name.clone());
131    });
132
133    //store the state of the primary_info.
134    let have_primary_key = primary_info.is_some();
135
136    //Note: the default primary key named "id" at the LAST position of the column_names.
137    if !have_primary_key {
138        // have no primary key, so we use the default "id" as the primary key.
139        // We must delete the this element below if we want to use the column_names as the real column names.
140        // The id is a serial primary key, so we don't need to set it in the insert sql.
141        column_names.push(" id SERIAL PRIMARY KEY ".to_string());
142    }
143
144    //this is the sql to create a tale we will return.
145    let create_sql: String;
146
147    //this is the sql to insert a row into the table.
148    let insert_sql: String;
149
150    //if we have a primary key.
151    if have_primary_key {
152        create_sql = format!("CREATE TABLE IF NOT EXISTS  {} ( {});", 
153            name_lower, fields_create.join(", "));
154        
155        //question mark is the placeholder for the values wich will be inserted.
156        //insertsql will be like this: "INSERT INTO A ( pid, hash, uuid ) VALUES ( $1, $2, $3 )", 
157        let mut  vec_palceholder = Vec::new();
158        for i in 1 ..= column_names.len() {
159            let i_str = i.to_string();
160            let placeholder = format!("${}", i_str);
161            vec_palceholder.push(placeholder);
162        }
163        insert_sql = format!("INSERT INTO {} ( {} ) VALUES ( {} )", 
164            name, column_names.join(", "), vec_palceholder.join(", "));
165        
166    }else {//have no primary key, but have the default "id" as the primary key.
167
168        //we need to know the "id" is the last element of the column_names.
169        //we must delet the "id" before we create a insert sql, for the id is a serial primary key, 
170        //so we don't need to set it in the insert sql.
171        create_sql = format!("CREATE TABLE IF NOT EXISTS {}  ( {} );",
172            name_lower, fields_create.join(", "));
173
174        //delete the "id" from the column_names.
175        column_names.pop();
176        
177        //question mark is the placeholder for the values wich will be inserted.
178        //insertsql will be like this: "INSERT INTO a ( pid, hash, uuid ) VALUES ( ?, ?, ? )", 
179        let vec_question_mark = vec!["?"; fields.len()];
180        insert_sql = format!("INSERT INTO {} ( {} ) VALUES ( {} )", 
181            name_lower, column_names.join(", "), vec_question_mark.join(", "));
182    }
183    //Note: at here, column_names is the real column names, not include the "id" at the end.
184
185    println!("create_sql: {}", create_sql);
186    println!("insert_sql: {}", insert_sql);
187   
188    
189    //the name of method
190    let expanded = quote! {
191
192        impl #impl_generics #name  #ty_generics #where_clause{
193            ///this is the sql to create a tale.
194            pub fn sql_create() -> String {
195                #create_sql.to_string()
196            }
197            
198            ///this is the sql to insert a row into the table.
199            pub fn sql_insert() -> String {
200                #insert_sql.to_string()
201            }          
202        
203            ///function to generate the arguments for sqlx query.
204            ///Note: this function only valid for the feature "sqlx-support".
205            /// 
206            ///If your struct A have derived the TableCreator, you can use it to generate the arguments 
207            ///for sqlx query like this:
208            /// A a = A {.....  }
209            /// let args = a.to_sqlx_args();
210            #[cfg(feature = "sqlx-support")]
211            pub fn to_sqlx_args(&self) ->sqlx::postgres::PgArguments {
212                let mut args = sqlx::postgres::PgArguments::default();
213                
214                // We need to import the trait `Arguments` to use the `add` method, it is required by sqlx for
215                // "add" method to work.
216                use sqlx::Arguments; 
217                #(args.add(&self.#column_ident);)*
218                args
219            }
220        } 
221        
222    };
223
224    proc_macro::TokenStream::from(expanded)
225}
226
227
228/*
229#[cfg(test)]
230mod test{
231
232    // Intended for the macro crate's tests/debug_test.rs; note that `derive_table_creator_impl` is not defined in this file.
233    use super::*;
234    use syn::{parse_quote, DeriveInput};
235
236    #[test]
237    fn debug_macro() {
238        let input: DeriveInput= parse_quote! {
239           #[derive(TableCreator,Debug)]
240            struct A {
241                #[primary_key]
242                pid: i32,
243                #[table_ignore]
244                hash: Hash,
245                uuid: UuidBytes,
246            }
247         
248        };
249     
250        let tokens = derive_table_creator_impl(proc_macro2::TokenStream::from( quote! { #input } ).into() ); 
251        
252
253        println!("{}", tokens);
254    }
255}
256*/