sqlx_plus_rs/
lib.rs

1extern crate proc_macro;
2extern crate sqlx_plus_builder;
3extern crate sqlx_plus_core;
4
5use proc_macro::TokenStream;
6use quote::quote;
7use sqlx_plus_core::generate_create_table_sql;
8use syn::{parse_macro_input, DeriveInput, Lit, Meta, MetaNameValue};
9
10#[proc_macro_derive(SqlxPlus, attributes(table_name))]
11pub fn derive_sqlx_crud(input: TokenStream) -> TokenStream {
12    let input = parse_macro_input!(input as DeriveInput);
13    let name = &input.ident;
14
15    // 解析 table_name 属性
16    let table_name = input
17        .attrs
18        .iter()
19        .find_map(|attr| {
20            if attr.path.is_ident("table_name") {
21                if let Meta::NameValue(MetaNameValue {
22                    lit: Lit::Str(lit_str),
23                    ..
24                }) = attr.parse_meta().ok()?
25                {
26                    Some(lit_str.value())
27                } else {
28                    None
29                }
30            } else {
31                None
32            }
33        })
34        .expect("Expected a table_name attribute");
35
36    // 获取结构体的字段
37    let fields: Vec<_> = if let syn::Data::Struct(data) = &input.data {
38        data.fields.iter().collect()
39    } else {
40        panic!("SqlxPlus can only be derived for structs");
41    };
42
43    // 生成列名和占位符
44    let columns: Vec<_> = fields
45        .iter()
46        .map(|f| f.ident.as_ref().unwrap().to_string())
47        .collect();
48    let field_names: Vec<_> = fields.iter().map(|f| f.ident.as_ref().unwrap()).collect();
49
50    let columns_str = columns.join(", ");
51
52    // 将 Vec<&syn::Field> 转换为 Vec<syn::Field>
53    let fields: Vec<syn::Field> = fields.into_iter().cloned().collect();
54    // 调用生成 create_table_sql 的函数
55    let create_table_sql = generate_create_table_sql(&fields, &table_name);
56
57    let expanded = quote! {
58        use sqlx_plus_core::{DbPool, SqlxBase, DbTransaction};
59        use sqlx_plus_builder::SqlBuilder;
60
61        impl #name {
62
63            pub async fn create_table_if_not_exists(pool: &DbPool) -> Result<(), sqlx::Error> {
64                pool.create_table_if_not_exists(#table_name, #create_table_sql).await
65            }
66
67            pub async fn save(pool: &DbPool, transaction: Option<&mut DbTransaction<'_>>, entity: &Self) -> Result<u64, sqlx::Error> {
68                let values_str = format!("{}",vec![#(format!("'{}'", entity.#field_names)),*].join(", "));
69                let sql = format!("INSERT INTO {} ({}) VALUES ({})", #table_name, #columns_str, values_str);
70                pool.insert(&sql, transaction).await
71            }
72
73            pub async fn saver(pool: &DbPool, transaction: Option<&mut DbTransaction<'_>>, entity: &Self) -> Result<Self, sqlx::Error> {
74                let values_str = format!("{}",vec![#(format!("'{}'", entity.#field_names)),*].join(", "));
75                let sql = format!("INSERT INTO {} ({}) VALUES ({}) RETURNING *", #table_name, #columns_str, values_str);
76                pool.insert_and_return_key(&sql, transaction).await
77            }
78
79            pub async fn get(pool: &DbPool, id: i64) -> Result<Option<Self>, sqlx::Error> {
80                let sql = format!("SELECT * FROM {} WHERE id = {}", #table_name, id);
81                pool.select_one(&sql).await
82            }
83
84            pub async fn update(pool: &DbPool, transaction: Option<&mut DbTransaction<'_>>, entity: &Self) -> Result<u64, sqlx::Error> {
85                let sql = format!("UPDATE {} SET ... WHERE id = {}", #table_name, entity.id);
86                pool.update(&sql, transaction).await
87            }
88
89            pub async fn delete(pool: &DbPool, transaction: Option<&mut DbTransaction<'_>>, id: i64) -> Result<u64, sqlx::Error> {
90                let sql = format!("DELETE FROM {} WHERE id = {}", #table_name, id);
91                pool.delete(&sql, transaction).await
92            }
93
94            pub async fn list(pool: &DbPool) -> Result<Vec<Self>, sqlx::Error> {
95                let sql = format!("SELECT * FROM {}", #table_name);
96                pool.select_all(&sql).await
97            }
98
99            pub async fn page(pool: &DbPool, limit: i64, offset: i64) -> Result<Vec<Self>, sqlx::Error> {
100                let sql = format!("SELECT * FROM {} LIMIT {} OFFSET {}", #table_name, limit, offset);
101                pool.select_page(&sql).await
102            }
103
104            pub async fn count(pool: &DbPool) -> Result<i64, sqlx::Error> {
105                let sql = format!("SELECT COUNT(*) FROM {}", #table_name);
106                pool.count(&sql).await
107            }
108        }
109    };
110
111    TokenStream::from(expanded)
112}