1extern crate proc_macro;
2extern crate sqlx_plus_builder;
3extern crate sqlx_plus_core;
4
5use proc_macro::TokenStream;
6use quote::quote;
7use sqlx_plus_core::generate_create_table_sql;
8use syn::{parse_macro_input, DeriveInput, Lit, Meta, MetaNameValue};
9
10#[proc_macro_derive(SqlxPlus, attributes(table_name))]
11pub fn derive_sqlx_crud(input: TokenStream) -> TokenStream {
12 let input = parse_macro_input!(input as DeriveInput);
13 let name = &input.ident;
14
15 let table_name = input
17 .attrs
18 .iter()
19 .find_map(|attr| {
20 if attr.path.is_ident("table_name") {
21 if let Meta::NameValue(MetaNameValue {
22 lit: Lit::Str(lit_str),
23 ..
24 }) = attr.parse_meta().ok()?
25 {
26 Some(lit_str.value())
27 } else {
28 None
29 }
30 } else {
31 None
32 }
33 })
34 .expect("Expected a table_name attribute");
35
36 let fields: Vec<_> = if let syn::Data::Struct(data) = &input.data {
38 data.fields.iter().collect()
39 } else {
40 panic!("SqlxPlus can only be derived for structs");
41 };
42
43 let columns: Vec<_> = fields
45 .iter()
46 .map(|f| f.ident.as_ref().unwrap().to_string())
47 .collect();
48 let field_names: Vec<_> = fields.iter().map(|f| f.ident.as_ref().unwrap()).collect();
49
50 let columns_str = columns.join(", ");
51
52 let fields: Vec<syn::Field> = fields.into_iter().cloned().collect();
54 let create_table_sql = generate_create_table_sql(&fields, &table_name);
56
57 let expanded = quote! {
58 use sqlx_plus_core::{DbPool, SqlxBase, DbTransaction};
59 use sqlx_plus_builder::SqlBuilder;
60
61 impl #name {
62
63 pub async fn create_table_if_not_exists(pool: &DbPool) -> Result<(), sqlx::Error> {
64 pool.create_table_if_not_exists(#table_name, #create_table_sql).await
65 }
66
67 pub async fn save(pool: &DbPool, transaction: Option<&mut DbTransaction<'_>>, entity: &Self) -> Result<u64, sqlx::Error> {
68 let values_str = format!("{}",vec![#(format!("'{}'", entity.#field_names)),*].join(", "));
69 let sql = format!("INSERT INTO {} ({}) VALUES ({})", #table_name, #columns_str, values_str);
70 pool.insert(&sql, transaction).await
71 }
72
73 pub async fn saver(pool: &DbPool, transaction: Option<&mut DbTransaction<'_>>, entity: &Self) -> Result<Self, sqlx::Error> {
74 let values_str = format!("{}",vec![#(format!("'{}'", entity.#field_names)),*].join(", "));
75 let sql = format!("INSERT INTO {} ({}) VALUES ({}) RETURNING *", #table_name, #columns_str, values_str);
76 pool.insert_and_return_key(&sql, transaction).await
77 }
78
79 pub async fn get(pool: &DbPool, id: i64) -> Result<Option<Self>, sqlx::Error> {
80 let sql = format!("SELECT * FROM {} WHERE id = {}", #table_name, id);
81 pool.select_one(&sql).await
82 }
83
84 pub async fn update(pool: &DbPool, transaction: Option<&mut DbTransaction<'_>>, entity: &Self) -> Result<u64, sqlx::Error> {
85 let sql = format!("UPDATE {} SET ... WHERE id = {}", #table_name, entity.id);
86 pool.update(&sql, transaction).await
87 }
88
89 pub async fn delete(pool: &DbPool, transaction: Option<&mut DbTransaction<'_>>, id: i64) -> Result<u64, sqlx::Error> {
90 let sql = format!("DELETE FROM {} WHERE id = {}", #table_name, id);
91 pool.delete(&sql, transaction).await
92 }
93
94 pub async fn list(pool: &DbPool) -> Result<Vec<Self>, sqlx::Error> {
95 let sql = format!("SELECT * FROM {}", #table_name);
96 pool.select_all(&sql).await
97 }
98
99 pub async fn page(pool: &DbPool, limit: i64, offset: i64) -> Result<Vec<Self>, sqlx::Error> {
100 let sql = format!("SELECT * FROM {} LIMIT {} OFFSET {}", #table_name, limit, offset);
101 pool.select_page(&sql).await
102 }
103
104 pub async fn count(pool: &DbPool) -> Result<i64, sqlx::Error> {
105 let sql = format!("SELECT COUNT(*) FROM {}", #table_name);
106 pool.count(&sql).await
107 }
108 }
109 };
110
111 TokenStream::from(expanded)
112}