sqlxinsert/
lib.rs

1// lib.rs
2mod common;
3
4extern crate proc_macro;
5use self::proc_macro::TokenStream;
6
7use quote::quote;
8
9use syn::{parse_macro_input, Data, DataStruct, DeriveInput, Fields};
10
11use crate::common::dollar_values;
12
/// Derive helper methods for inserting and updating a struct in an SQLite database.
///
/// The derive only accepts structs with named fields. Each field name is used verbatim as a
/// column name (a raw identifier such as `r#type` maps to the column `type`), and every query
/// binds the fields in declaration order, so the first field is always bound to `$1`.
///
/// Generated inherent methods:
///
/// * `insert_query(&self, table) -> String` builds
///   `insert into <table> ( <columns> ) values ( <placeholders> )`.
/// * `insert_raw(&self, pool, table)` executes that statement and returns the
///   `sqlx::sqlite::SqliteQueryResult`.
/// * `update_query(&self, table) -> String` builds
///   `update <table> set <column2> = $2,... where <column1> = $1`: the first field selects the
///   row and the remaining fields are overwritten. The struct therefore needs at least two
///   fields for this statement to be valid SQL.
/// * `update_raw(&self, pool, table)` executes that statement.
///
/// Applying the derive to an enum, a union, or a tuple/unit struct produces a compile error
/// pointing at the item's name.
///
/// ```rust
/// # #[tokio::main]
/// # async fn main() -> sqlx::Result<()>{
/// #[derive(Default, Debug, sqlx::FromRow, sqlxinsert::SqliteInsert)]
/// struct Car {
///     pub car_id: i32,
///     pub car_name: String,
/// }
///
/// let car = Car {
///     car_id: 33,
///     car_name: "Skoda".to_string(),
/// };
///
/// let url = "sqlite::memory:";
/// let pool = sqlx::sqlite::SqlitePoolOptions::new().connect(url).await.unwrap();
///
/// let create_table = "create table cars ( car_id INTEGER PRIMARY KEY, car_name TEXT NOT NULL )";
/// sqlx::query(create_table).execute(&pool).await.expect("Not possible to execute");
///
/// let res = car.insert_raw(&pool, "cars").await.unwrap(); // sqlx::sqlite::SqliteQueryResult
///
/// // The row is matched on the first field (`car_id`); the other fields are overwritten.
/// let renamed = Car {
///     car_id: 33,
///     car_name: "Octavia".to_string(),
/// };
/// renamed.update_raw(&pool, "cars").await.unwrap();
/// # Ok(())
/// # }
/// ```
///
#[cfg(feature = "sqlite")]
#[proc_macro_derive(SqliteInsert)]
pub fn derive_from_struct_sqlite(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as DeriveInput);

    let fields = match &input.data {
        Data::Struct(DataStruct {
            fields: Fields::Named(fields),
            ..
        }) => &fields.named,
        // Report a regular compile error at the item instead of panicking inside the compiler.
        _ => {
            return syn::Error::new_spanned(&input.ident, "expected a struct with named fields")
                .to_compile_error()
                .into();
        }
    };
    let struct_name = &input.ident;
    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();

    // Field identifiers in declaration order; this is also the bind order of every query.
    // Collected once so the same list can be interpolated into several methods.
    let idents: Vec<&syn::Ident> = fields
        .iter()
        .filter_map(|field| field.ident.as_ref())
        .collect();

    // Column names: strip the `r#` prefix of raw identifiers so `r#type` targets `type`.
    let columns: Vec<String> = idents
        .iter()
        .map(|&ident| syn::ext::IdentExt::unraw(ident).to_string())
        .collect();

    // "id,name,hostname" and the matching positional placeholders.
    let column_list = columns.join(",");
    let dollars = dollar_values(columns.len());

    // UPDATE: the first field (bound as $1) identifies the row, so the WHERE clause must name
    // *its* column rather than assuming a column called `id`.
    let key_column = columns.first().cloned().unwrap_or_default();
    // "name = $2,hostname = $3": every field except the key, numbered from $2.
    let pairs = columns
        .iter()
        .enumerate()
        .skip(1)
        .map(|(index, column)| format!("{} = ${}", column, index + 1))
        .collect::<Vec<String>>()
        .join(",");

    let expanded = quote! {
        impl #impl_generics #struct_name #ty_generics #where_clause {
            pub fn insert_query(&self, table: &str) -> String {
                format!("insert into {} ( {} ) values ( {} )", table, #column_list, #dollars)
            }

            pub async fn insert_raw(&self, pool: &sqlx::SqlitePool, table: &str) -> Result<sqlx::sqlite::SqliteQueryResult, sqlx::Error> {
                let sql = self.insert_query(table);
                sqlx::query(&sql)
                    #( .bind(&self.#idents) )*
                    .execute(pool)
                    .await
            }

            pub fn update_query(&self, table: &str) -> String {
                format!("update {} set {} where {} = $1", table, #pairs, #key_column)
            }

            pub async fn update_raw(&self, pool: &sqlx::SqlitePool, table: &str) -> Result<sqlx::sqlite::SqliteQueryResult, sqlx::Error> {
                let sql = self.update_query(table);
                sqlx::query(&sql)
                    #( .bind(&self.#idents) )*
                    .execute(pool)
                    .await
            }
        }
    };

    TokenStream::from(expanded)
}
124
/// Derive helper methods for inserting and updating a struct in a Postgres database.
///
/// The derive only accepts structs with named fields. Each field name is used verbatim as a
/// column name (a raw identifier such as `r#type` maps to the column `type`), and every query
/// binds the fields in declaration order, so the first field is always bound to `$1`.
///
/// Generated inherent methods:
///
/// * `insert::<T>(&self, pool, table)` runs
///   `insert into <table> ( <columns> ) values ( <placeholders> ) returning *` and decodes the
///   returned row as `T`.
/// * `insert_ex(&self, executor, table)` runs the same statement on any
///   `sqlx::Executor<Database = sqlx::Postgres>` and discards the result.
/// * `update::<T>(&self, pool, table)` runs
///   `update <table> set <column2> = $2,... where <column1> = $1 returning *`: the first field
///   selects the row and the remaining fields are overwritten. The updated row is decoded as
///   `T`. The struct needs at least two fields for this statement to be valid SQL.
/// * `update_ex(&self, executor, table)` runs the same update on any Postgres executor.
///
/// Applying the derive to an enum, a union, or a tuple/unit struct produces a compile error
/// pointing at the item's name.
///
/// ```rust,ignore
/// # #[tokio::main]
/// # async fn main() -> sqlx::Result<()> {
///
/// #[derive(Default, Debug, std::cmp::PartialEq, sqlx::FromRow)]
/// struct Car {
///     pub id: i32,
///     pub name: String,
/// }
///
/// #[derive(Default, Debug, sqlx::FromRow, sqlxinsert::PgInsert)]
/// struct CreateCar {
///     pub name: String,
///     pub color: Option<String>,
/// }
/// impl CreateCar {
///     pub fn new<T: Into<String>>(name: T) -> Self {
///         CreateCar {
///             name: name.into(),
///             color: None,
///         }
///     }
/// }
/// let url = "postgres://user:pass@localhost:5432/test_db";
/// let pool = sqlx::postgres::PgPoolOptions::new().connect(&url).await.unwrap();
///
/// let car_skoda = CreateCar::new("Skoda");
/// let res: Car = car_skoda.insert::<Car>(&pool, "cars").await?;
/// # Ok(())
/// # }
/// ```
///
#[cfg(feature = "postgres")]
#[proc_macro_derive(PgInsert)]
pub fn derive_from_struct_psql(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as DeriveInput);

    let fields = match &input.data {
        Data::Struct(DataStruct {
            fields: Fields::Named(fields),
            ..
        }) => &fields.named,
        // Report a regular compile error at the item instead of panicking inside the compiler.
        _ => {
            return syn::Error::new_spanned(&input.ident, "expected a struct with named fields")
                .to_compile_error()
                .into();
        }
    };
    let struct_name = &input.ident;
    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();

    // Field identifiers in declaration order; this is also the bind order of every query.
    // Collected once so the same list can be interpolated into all four methods.
    let idents: Vec<&syn::Ident> = fields
        .iter()
        .filter_map(|field| field.ident.as_ref())
        .collect();

    // Column names: strip the `r#` prefix of raw identifiers so `r#type` targets `type`.
    let columns: Vec<String> = idents
        .iter()
        .map(|&ident| syn::ext::IdentExt::unraw(ident).to_string())
        .collect();

    // "id,name,hostname" and the matching positional placeholders.
    let column_list = columns.join(",");
    let dollars = dollar_values(columns.len());

    // UPDATE: the first field (bound as $1) identifies the row, so the WHERE clause must name
    // *its* column rather than assuming a column called `id`.
    let key_column = columns.first().cloned().unwrap_or_default();
    // "name = $2,hostname = $3": every field except the key, numbered from $2.
    let pairs = columns
        .iter()
        .enumerate()
        .skip(1)
        .map(|(index, column)| format!("{} = ${}", column, index + 1))
        .collect::<Vec<String>>()
        .join(",");

    let expanded = quote! {
        impl #impl_generics #struct_name #ty_generics #where_clause {
            fn insert_query(&self, table: &str) -> String {
                format!("insert into {} ( {} ) values ( {} ) returning *", table, #column_list, #dollars)
            }

            pub async fn insert<T>(&self, pool: &sqlx::PgPool, table: &str) -> sqlx::Result<T>
            where
                T: Send,
                T: for<'c> sqlx::FromRow<'c, sqlx::postgres::PgRow>,
                T: std::marker::Unpin
            {
                let sql = self.insert_query(table);
                sqlx::query_as::<_, T>(&sql)
                    #( .bind(&self.#idents) )*
                    .fetch_one(pool)
                    .await
            }

            pub async fn insert_ex<'e, E>(&self, executor: E, table: &str) -> sqlx::Result<()>
            where
                E: sqlx::Executor<'e, Database = sqlx::Postgres>
            {
                let sql = self.insert_query(table);
                sqlx::query(&sql)
                    #( .bind(&self.#idents) )*
                    .execute(executor)
                    .await?;
                Ok(())
            }

            fn update_query(&self, table: &str) -> String {
                format!("update {} set {} where {} = $1 returning *", table, #pairs, #key_column)
            }

            pub async fn update<T>(&self, pool: &sqlx::PgPool, table: &str) -> sqlx::Result<T>
            where
                T: Send,
                T: for<'c> sqlx::FromRow<'c, sqlx::postgres::PgRow>,
                T: std::marker::Unpin
            {
                let sql = self.update_query(table);
                sqlx::query_as::<_, T>(&sql)
                    #( .bind(&self.#idents) )*
                    .fetch_one(pool)
                    .await
            }

            pub async fn update_ex<'e, E>(&self, executor: E, table: &str) -> sqlx::Result<()>
            where
                E: sqlx::Executor<'e, Database = sqlx::Postgres>
            {
                let sql = self.update_query(table);
                sqlx::query(&sql)
                    #( .bind(&self.#idents) )*
                    .execute(executor)
                    .await?;
                Ok(())
            }
        }
    };

    TokenStream::from(expanded)
}