taitan_orm_parser/
sql_generator.rs

1use crate::{DatabaseType, FieldMapper, SqlType, TableDef};
2use proc_macro2::TokenStream;
3use quote::{format_ident, quote};
4use crate::condition_def::{ConditionDef, VariantsOrFields};
5
/// Stateless generator of SQL-building code.
///
/// The type carries no fields; every method returns a `TokenStream` meant to be
/// spliced into generated Rust code (the insert/upsert/where generators emit code
/// that evaluates to a `std::borrow::Cow::Owned` SQL string).
#[derive(Debug, Default)]
pub struct SqlGenerator;
8impl SqlGenerator {
9    pub fn gen_sql(
10        &self,
11        db_type: &DatabaseType,
12        sql_type: &SqlType,
13        table: &TableDef,
14    ) -> TokenStream {
15        match sql_type {
16            SqlType::Insert => self.gen_insert_sql(table, db_type),
17            SqlType::Upsert => self.gen_upsert_sql(table, db_type),
18        }
19    }
20
21    pub fn gen_insert_sql(&self, table_def: &TableDef, db_type: &DatabaseType) -> TokenStream {
22        let field_mapper = FieldMapper::new();
23        let table_name = field_mapper.escape(&table_def.table_name, db_type);
24        let fields = field_mapper.gen_names(&table_def.fields, &db_type);
25        let marks = field_mapper.gen_marks(&table_def.fields, &db_type);
26        let sql_template = format!("INSERT INTO {table_name} ({{}}) VALUES({{}})");
27        quote! {
28            let fields = #fields;
29            let marks = #marks;
30            std::borrow::Cow::Owned(format!(#sql_template, fields, marks))
31        }
32    }
33    pub fn gen_upsert_sql(&self, table_def: &TableDef, db_type: &DatabaseType) -> TokenStream {
34        let field_mapper = FieldMapper::new();
35        let table_name = field_mapper.escape(&table_def.table_name, db_type);
36        let fields = field_mapper.gen_names(&table_def.fields, db_type);
37        let primary_fields = table_def.get_primary_fields();
38        let primary_fields_stream = field_mapper.gen_names(primary_fields, db_type);
39        let non_primary_fields = table_def.get_not_primary_fields();
40        let upsert_sets_stream = field_mapper.gen_upsert_sets(non_primary_fields, db_type);
41
42        let marks = field_mapper.gen_marks(&table_def.fields, db_type);
43        return match db_type {
44            DatabaseType::MySql => {
45                let sql = format!(
46                    "INSERT INTO {table_name} ({{}}) VALUES({{}}) ON DUPLICATE KEY UPDATE {{}}"
47                );
48                quote! {
49                    let fields = #fields;
50                    let marks = #marks;
51                    let upsert_sets = #upsert_sets_stream;
52                    std::borrow::Cow::Owned(format!(#sql, fields, marks, upsert_sets))
53                }
54            }
55            DatabaseType::Postgres => {
56                let sql = format!("INSERT INTO {table_name} ({{}}) VALUES({{}}) ON CONFLICT ({{}}) DO UPDATE SET {{}}");
57                quote! {
58                    let fields = #fields;
59                    let marks = #marks;
60                    let primarys = #primary_fields_stream;
61                    let upsert_sets = #upsert_sets_stream;
62                    std::borrow::Cow::Owned(format!(#sql, fields, marks, primarys, upsert_sets))
63                }
64            }
65            DatabaseType::Sqlite => {
66                let sql = format!("INSERT INTO {table_name} ({{}}) VALUES({{}}) ON CONFLICT ({{}}) DO UPDATE SET {{}}");
67                quote! {
68                    let fields = #fields;
69                    let marks = #marks;
70                    let primarys = #primary_fields_stream;
71                    let upsert_sets = #upsert_sets_stream;
72                    std::borrow::Cow::Owned(format!(#sql, fields, marks, primarys, upsert_sets))
73                }
74            }
75        };
76    }
77
78    pub fn gen_update_set_sql(&self, table_def: &TableDef, db_type: &DatabaseType) -> TokenStream {
79        let field_mapper = FieldMapper::new();
80        field_mapper.gen_sets(&table_def.fields, db_type)
81    }
82
83    pub fn gen_where_sql(&self, condition_def: &ConditionDef, db_type: &DatabaseType) -> TokenStream {
84        let field_mapper = FieldMapper::new();
85        let mut stream = TokenStream::new();
86        match &condition_def.variants_or_fields {
87            VariantsOrFields::Variants(variants) => {
88                for variant in variants {
89                    let variant_name = format_ident!("{}", &variant.name);
90                    let idents = field_mapper.gen_idents(&variant.fields);
91                    // panic!("idents: {}", idents);
92                    let s = field_mapper.gen_conditions(&variant.fields, db_type, true);
93                    if variant.named {
94                        stream.extend(quote! {
95                            Self::#variant_name{ #idents }=> {
96                                #s
97                            }
98                        });
99                    } else {
100                        stream.extend(quote! {
101                    Self::#variant_name( #idents )=> {
102                        #s
103                    }
104                });
105                    }
106                }
107
108                quote! {
109                    let s = match self {
110                        #stream
111                    };
112                    std::borrow::Cow::Owned(s)
113                }
114            }
115            VariantsOrFields::Fields(fields) => {
116                let stream = field_mapper.gen_conditions(fields, db_type, false);
117                quote! {
118                    let s =  {
119                        #stream
120                    };
121                    std::borrow::Cow::Owned(s)
122                }
123            }
124        }
125
126
127    }
128
129    pub fn gen_select_sql(&self, table_def: &TableDef, db_type: &DatabaseType) -> TokenStream {
130        let field_mapper = FieldMapper::new();
131        field_mapper.gen_names(&table_def.fields, db_type)
132    }
133}