sqlx_plus_rs/
lib.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
extern crate proc_macro;
extern crate sqlx_plus_builder;
extern crate sqlx_plus_core;

use proc_macro::TokenStream;
use quote::quote;
use sqlx_plus_core::generate_create_table_sql;
use syn::{parse_macro_input, DeriveInput, Lit, Meta, MetaNameValue};

#[proc_macro_derive(SqlxPlus, attributes(table_name))]
pub fn derive_sqlx_crud(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as DeriveInput);
    let name = &input.ident;

    // 解析 table_name 属性
    let table_name = input
        .attrs
        .iter()
        .find_map(|attr| {
            if attr.path.is_ident("table_name") {
                if let Meta::NameValue(MetaNameValue {
                    lit: Lit::Str(lit_str),
                    ..
                }) = attr.parse_meta().ok()?
                {
                    Some(lit_str.value())
                } else {
                    None
                }
            } else {
                None
            }
        })
        .expect("Expected a table_name attribute");

    // 获取结构体的字段
    let fields: Vec<_> = if let syn::Data::Struct(data) = &input.data {
        data.fields.iter().collect()
    } else {
        panic!("SqlxPlus can only be derived for structs");
    };

    // 生成列名和占位符
    let columns: Vec<_> = fields
        .iter()
        .map(|f| f.ident.as_ref().unwrap().to_string())
        .collect();
    let field_names: Vec<_> = fields.iter().map(|f| f.ident.as_ref().unwrap()).collect();

    let columns_str = columns.join(", ");

    // 将 Vec<&syn::Field> 转换为 Vec<syn::Field>
    let fields: Vec<syn::Field> = fields.into_iter().cloned().collect();
    // 调用生成 create_table_sql 的函数
    let create_table_sql = generate_create_table_sql(&fields, &table_name);

    let expanded = quote! {
        use sqlx_plus_core::{DbPool, SqlxBase, DbTransaction};
        use sqlx_plus_builder::SqlBuilder;

        impl #name {

            pub async fn create_table_if_not_exists(pool: &DbPool) -> Result<(), sqlx::Error> {
                pool.create_table_if_not_exists(#table_name, #create_table_sql).await
            }

            pub async fn save(pool: &DbPool, transaction: Option<&mut DbTransaction<'_>>, entity: &Self) -> Result<u64, sqlx::Error> {
                let values_str = format!("{}",vec![#(format!("'{}'", entity.#field_names)),*].join(", "));
                let sql = format!("INSERT INTO {} ({}) VALUES ({})", #table_name, #columns_str, values_str);
                pool.insert(&sql, transaction).await
            }

            pub async fn saver(pool: &DbPool, transaction: Option<&mut DbTransaction<'_>>, entity: &Self) -> Result<Self, sqlx::Error> {
                let values_str = format!("{}",vec![#(format!("'{}'", entity.#field_names)),*].join(", "));
                let sql = format!("INSERT INTO {} ({}) VALUES ({}) RETURNING *", #table_name, #columns_str, values_str);
                pool.insert_and_return_key(&sql, transaction).await
            }

            pub async fn get(pool: &DbPool, id: i64) -> Result<Option<Self>, sqlx::Error> {
                let sql = format!("SELECT * FROM {} WHERE id = {}", #table_name, id);
                pool.select_one(&sql).await
            }

            pub async fn update(pool: &DbPool, transaction: Option<&mut DbTransaction<'_>>, entity: &Self) -> Result<u64, sqlx::Error> {
                let sql = format!("UPDATE {} SET ... WHERE id = {}", #table_name, entity.id);
                pool.update(&sql, transaction).await
            }

            pub async fn delete(pool: &DbPool, transaction: Option<&mut DbTransaction<'_>>, id: i64) -> Result<u64, sqlx::Error> {
                let sql = format!("DELETE FROM {} WHERE id = {}", #table_name, id);
                pool.delete(&sql, transaction).await
            }

            pub async fn list(pool: &DbPool) -> Result<Vec<Self>, sqlx::Error> {
                let sql = format!("SELECT * FROM {}", #table_name);
                pool.select_all(&sql).await
            }

            pub async fn page(pool: &DbPool, limit: i64, offset: i64) -> Result<Vec<Self>, sqlx::Error> {
                let sql = format!("SELECT * FROM {} LIMIT {} OFFSET {}", #table_name, limit, offset);
                pool.select_page(&sql).await
            }

            pub async fn count(pool: &DbPool) -> Result<i64, sqlx::Error> {
                let sql = format!("SELECT COUNT(*) FROM {}", #table_name);
                pool.count(&sql).await
            }
        }
    };

    TokenStream::from(expanded)
}