1use crate::builder::query_builder::{BindValue, QueryBuilder};
2use crate::database_info::DatabaseInfo;
3use crate::error::{Result, SqlxPlusError};
4use crate::traits::Model;
5use sqlx::{Database, Row};
6
/// Alias for `i64`.
///
/// NOTE(review): presumably the default primary-key type for models — confirm against callers.
pub type Id = i64;
9
/// Binds one `BindValue` onto a sqlx query, reassigning `$query` in place.
///
/// * `$query` – a mutable place expression holding the query; it is overwritten with
///   the result of calling `.bind(..)` on it.
/// * `$bind` – an expression yielding the `BindValue` (or a reference to one) to bind.
///
/// `String`, `Int64`, `Int32`, `Int16`, `Float64`, `Float32`, `Bool` and `Bytes` payloads
/// are bound unchanged. The remaining integer variants are converted before binding:
/// `Int8`/`UInt8` become `i16`, `UInt16` becomes `i32`, `UInt32` becomes `i64`, and
/// `UInt64` is cast to `i64` with `as`, so values above `i64::MAX` wrap to negative numbers.
/// `Null` is bound as `Option::<String>::None`.
#[macro_export]
macro_rules! apply_bind_value {
    ($query:expr, $bind:expr) => {
        match $bind {
            // Payloads that are handed to sqlx unchanged.
            $crate::builder::query_builder::BindValue::String(text) => {
                $query = $query.bind(text);
            }
            $crate::builder::query_builder::BindValue::Bytes(raw) => {
                $query = $query.bind(raw);
            }
            $crate::builder::query_builder::BindValue::Bool(flag) => {
                $query = $query.bind(flag);
            }
            $crate::builder::query_builder::BindValue::Float64(value) => {
                $query = $query.bind(value);
            }
            $crate::builder::query_builder::BindValue::Float32(value) => {
                $query = $query.bind(value);
            }
            $crate::builder::query_builder::BindValue::Int64(value) => {
                $query = $query.bind(value);
            }
            $crate::builder::query_builder::BindValue::Int32(value) => {
                $query = $query.bind(value);
            }
            $crate::builder::query_builder::BindValue::Int16(value) => {
                $query = $query.bind(value);
            }
            // Integer payloads that are widened losslessly to a larger signed type.
            $crate::builder::query_builder::BindValue::Int8(ref small) => {
                $query = $query.bind(i16::from(*small));
            }
            $crate::builder::query_builder::BindValue::UInt8(ref small) => {
                $query = $query.bind(i16::from(*small));
            }
            $crate::builder::query_builder::BindValue::UInt16(ref medium) => {
                $query = $query.bind(i32::from(*medium));
            }
            $crate::builder::query_builder::BindValue::UInt32(ref large) => {
                $query = $query.bind(i64::from(*large));
            }
            // Same-width cast: values above `i64::MAX` wrap to negative numbers.
            $crate::builder::query_builder::BindValue::UInt64(ref huge) => {
                $query = $query.bind(*huge as i64);
            }
            // NULL is sent as an absent string.
            $crate::builder::query_builder::BindValue::Null => {
                $query = $query.bind(Option::<String>::None);
            }
        }
    };
}
69
70fn apply_binds_to_query_generic<'q, DB>(
88 mut query: sqlx::query::Query<'q, DB, DB::Arguments<'q>>,
89 binds: &'q [BindValue],
90) -> sqlx::query::Query<'q, DB, DB::Arguments<'q>>
91where
92 DB: Database + DatabaseInfo,
93 for<'a> DB::Arguments<'a>: sqlx::IntoArguments<'a, DB>,
94 String: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
97 i64: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
98 i32: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
99 i16: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
100 f64: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
101 f32: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
102 bool: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
103 Vec<u8>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
104 Option<String>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
105{
106 for bind in binds {
107 crate::apply_bind_value!(query, bind);
108 }
109 query
110}
111
112fn apply_binds_to_query_as_generic<'q, DB, M>(
131 mut query: sqlx::query::QueryAs<'q, DB, M, DB::Arguments<'q>>,
132 binds: &'q [BindValue],
133) -> sqlx::query::QueryAs<'q, DB, M, DB::Arguments<'q>>
134where
135 DB: Database + DatabaseInfo,
136 for<'a> DB::Arguments<'a>: sqlx::IntoArguments<'a, DB>,
137 M: for<'r> sqlx::FromRow<'r, DB::Row>,
138 String: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
141 i64: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
142 i32: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
143 i16: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
144 f64: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
145 f32: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
146 bool: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
147 Vec<u8>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
148 Option<String>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
149{
150 for bind in binds {
151 crate::apply_bind_value!(query, bind);
152 }
153 query
154}
155
/// One page of query results together with its pagination metadata.
#[derive(Debug, Clone)]
pub struct Page<T> {
    /// Rows that belong to this page.
    pub items: Vec<T>,
    /// Total number of matching rows, stored exactly as passed to [`Page::new`].
    pub total: i64,
    /// Page number, stored exactly as passed to [`Page::new`].
    pub page: u32,
    /// Requested page size, stored exactly as passed to [`Page::new`].
    pub size: u32,
    /// Number of pages needed to hold `total` rows at `size` rows per page.
    pub pages: u32,
}

impl<T> Page<T> {
    /// Builds a page and derives `pages` as `ceil(total / size)`.
    ///
    /// * `items` – rows of the current page.
    /// * `total` – total number of matching rows; a negative value is counted as zero
    ///   rows when computing `pages` (the `total` field itself is stored unchanged).
    /// * `page` – page number, stored as given.
    /// * `size` – page size; `0` yields `pages == 0`.
    ///
    /// The page count saturates at `u32::MAX` instead of being truncated when it does
    /// not fit into a `u32`. This function never panics.
    pub fn new(items: Vec<T>, total: i64, page: u32, size: u32) -> Self {
        let pages = if size == 0 {
            0
        } else {
            // Non-negative after the clamp, so the cast to u64 is lossless.
            let rows = total.max(0) as u64;
            let per_page = u64::from(size);
            // Ceiling division without the `rows + per_page - 1` overflow hazard.
            let full_pages = rows / per_page;
            let page_count = if rows % per_page == 0 {
                full_pages
            } else {
                full_pages + 1
            };
            // Clamped to the u32 range first, so the cast cannot truncate.
            page_count.min(u64::from(u32::MAX)) as u32
        };
        Self {
            items,
            total,
            page,
            size,
            pages,
        }
    }
}
182
183pub async fn find_by_id<'e, 'c: 'e, DB, M, E>(
217 executor: E,
218 id: impl for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB> + Send + Sync,
219) -> Result<Option<M>>
220where
221 DB: Database + DatabaseInfo,
222 for<'a> DB::Arguments<'a>: sqlx::IntoArguments<'a, DB>,
223 M: Model + for<'r> sqlx::FromRow<'r, DB::Row> + Send + Unpin,
224 E: sqlx::Executor<'c, Database = DB> + Send,
225{
226 let escaped_table = DB::escape_identifier(M::TABLE);
228 let escaped_pk = DB::escape_identifier(M::PK);
229 let placeholder = DB::placeholder(0);
230
231 let sql_str = if let Some(soft_delete_field) = M::SOFT_DELETE_FIELD {
232 let escaped_field = DB::escape_identifier(soft_delete_field);
233 format!(
234 "SELECT * FROM {} WHERE {} = {} AND {} = 0",
235 escaped_table, escaped_pk, placeholder, escaped_field
236 )
237 } else {
238 format!(
239 "SELECT * FROM {} WHERE {} = {}",
240 escaped_table, escaped_pk, placeholder
241 )
242 };
243
244 match sqlx::query(&sql_str)
246 .bind(id)
247 .fetch_optional(executor)
248 .await?
249 {
250 Some(row) => Ok(Some(sqlx::FromRow::from_row(&row).map_err(|e| {
251 SqlxPlusError::DatabaseError(sqlx::Error::Decode(
252 Box::new(e) as Box<dyn std::error::Error + Send + Sync + 'static>
253 ))
254 })?)),
255 None => Ok(None),
256 }
257}
258
259pub async fn find_by_ids<'e, 'c: 'e, DB, M, I, E>(executor: E, ids: I) -> Result<Vec<M>>
298where
299 DB: Database + DatabaseInfo,
300 for<'a> DB::Arguments<'a>: sqlx::IntoArguments<'a, DB>,
301 M: Model + for<'r> sqlx::FromRow<'r, DB::Row> + Send + Unpin,
302 I: IntoIterator + Send,
303 I::Item: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB> + Send + Sync + Clone,
304 E: sqlx::Executor<'c, Database = DB> + Send,
305{
306 let ids_vec: Vec<_> = ids.into_iter().collect();
307 if ids_vec.is_empty() {
308 return Ok(Vec::new());
309 }
310
311 let escaped_table = DB::escape_identifier(M::TABLE);
313 let escaped_pk = DB::escape_identifier(M::PK);
314
315 let placeholders: Vec<String> = (0..ids_vec.len()).map(|i| DB::placeholder(i)).collect();
317 let placeholders_str = placeholders.join(", ");
318
319 let mut sql_str = format!(
320 "SELECT * FROM {} WHERE {} IN ({})",
321 escaped_table, escaped_pk, placeholders_str
322 );
323
324 if let Some(soft_delete_field) = M::SOFT_DELETE_FIELD {
325 let escaped_field = DB::escape_identifier(soft_delete_field);
326 sql_str.push_str(&format!(" AND {} = 0", escaped_field));
327 }
328
329 let mut query = sqlx::query_as::<DB, M>(&sql_str);
331 for id in &ids_vec {
332 query = query.bind(id.clone());
333 }
334 query
335 .fetch_all(executor)
336 .await
337 .map_err(|e| SqlxPlusError::DatabaseError(e))
338}
339
340pub async fn find_one<'e, 'c: 'e, DB, M, E>(executor: E, builder: QueryBuilder) -> Result<Option<M>>
381where
382 DB: Database + DatabaseInfo,
383 for<'a> DB::Arguments<'a>: sqlx::IntoArguments<'a, DB>,
384 M: Model + for<'r> sqlx::FromRow<'r, DB::Row> + Send + Unpin,
385 E: sqlx::Executor<'c, Database = DB> + Send,
386 String: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
390 i64: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
391 i32: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
392 i16: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
393 f64: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
394 f32: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
395 bool: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
396 Vec<u8>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
397 Option<String>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
398{
399 let driver = DB::get_driver();
401 let escaped_table = DB::escape_identifier(M::TABLE);
402
403 let mut query_builder = builder;
405 query_builder = query_builder.with_base_sql(format!("SELECT * FROM {}", escaped_table));
406
407 if let Some(soft_delete_field) = M::SOFT_DELETE_FIELD {
409 query_builder = query_builder.and_eq(soft_delete_field, 0);
410 }
411
412 let mut sql = query_builder.into_sql(driver);
414 sql.push_str(" LIMIT 1");
415
416 let binds = query_builder.binds().to_vec();
417 let query = sqlx::query_as::<DB, M>(&sql);
418 let query = apply_binds_to_query_as_generic(query, &binds);
419
420 query
421 .fetch_optional(executor)
422 .await
423 .map_err(|e| SqlxPlusError::DatabaseError(e))
424}
425
426pub async fn find_all<'e, 'c: 'e, DB, M, E>(
465 executor: E,
466 builder: Option<QueryBuilder>,
467) -> Result<Vec<M>>
468where
469 DB: Database + DatabaseInfo,
470 for<'a> DB::Arguments<'a>: sqlx::IntoArguments<'a, DB>,
471 M: Model + for<'r> sqlx::FromRow<'r, DB::Row> + Send + Unpin,
472 E: sqlx::Executor<'c, Database = DB> + Send,
473 String: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
477 i64: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
478 i32: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
479 i16: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
480 f64: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
481 f32: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
482 bool: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
483 Vec<u8>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
484 Option<String>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
485{
486 let driver = DB::get_driver();
488 let escaped_table = DB::escape_identifier(M::TABLE);
489
490 let mut query_builder =
492 builder.unwrap_or_else(|| QueryBuilder::new(format!("SELECT * FROM {}", escaped_table)));
493 query_builder = query_builder.with_base_sql(format!("SELECT * FROM {}", escaped_table));
494
495 if let Some(soft_delete_field) = M::SOFT_DELETE_FIELD {
497 query_builder = query_builder.and_eq(soft_delete_field, 0);
498 }
499
500 let mut sql = query_builder.into_sql(driver);
502 sql.push_str(" LIMIT 1000");
503
504 let binds = query_builder.binds().to_vec();
505 let query = sqlx::query_as::<DB, M>(&sql);
506 let query = apply_binds_to_query_as_generic(query, &binds);
507
508 query
509 .fetch_all(executor)
510 .await
511 .map_err(|e| SqlxPlusError::DatabaseError(e))
512}
513
514pub async fn count<'e, 'c: 'e, DB, M, E>(executor: E, builder: QueryBuilder) -> Result<u64>
555where
556 DB: Database + DatabaseInfo,
557 for<'a> DB::Arguments<'a>: sqlx::IntoArguments<'a, DB>,
558 M: Model,
559 E: sqlx::Executor<'c, Database = DB> + Send,
560 String: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
566 i64: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB> + for<'r> sqlx::Decode<'r, DB>,
567 i32: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
568 i16: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
569 f64: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
570 f32: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
571 bool: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
572 Vec<u8>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
573 Option<String>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
574 usize: sqlx::ColumnIndex<DB::Row>,
575{
576 let driver = DB::get_driver();
578 let escaped_table = DB::escape_identifier(M::TABLE);
579
580 let mut query_builder = builder;
581 query_builder = query_builder.with_base_sql(format!("SELECT * FROM {}", escaped_table));
582
583 if let Some(soft_delete_field) = M::SOFT_DELETE_FIELD {
584 query_builder = query_builder.and_eq(soft_delete_field, 0);
585 }
586
587 let binds = query_builder.binds().to_vec();
588 let count_sql = query_builder.into_count_sql(driver);
589 let query = sqlx::query::<DB>(&count_sql);
590 let query = apply_binds_to_query_generic(query, &binds);
591
592 let row = query.fetch_one(executor).await?;
593 let count_value: i64 = row.get(0usize);
595 Ok(count_value as u64)
596}
597
598pub async fn paginate<'e, 'c: 'e, DB, M, E>(
641 executor: E,
642 mut builder: QueryBuilder,
643 page: u32,
644 size: u32,
645) -> Result<Page<M>>
646where
647 DB: Database + DatabaseInfo,
648 for<'a> DB::Arguments<'a>: sqlx::IntoArguments<'a, DB>,
649 M: Model + for<'r> sqlx::FromRow<'r, DB::Row> + Send + Unpin,
650 E: sqlx::Executor<'c, Database = DB> + Send + Clone,
651 String: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
657 i64: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB> + for<'r> sqlx::Decode<'r, DB>,
658 i32: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
659 i16: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
660 f64: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
661 f32: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
662 bool: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
663 Vec<u8>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
664 Option<String>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
665 usize: sqlx::ColumnIndex<DB::Row>,
666{
667 let offset = ((page as u64).saturating_sub(1) * size as u64) as u32;
668 let driver = DB::get_driver();
669 let escaped_table = DB::escape_identifier(M::TABLE);
670
671 builder = builder.with_base_sql(format!("SELECT * FROM {}", escaped_table));
672
673 if let Some(soft_delete_field) = M::SOFT_DELETE_FIELD {
674 builder = builder.and_eq(soft_delete_field, 0);
675 }
676
677 let binds = builder.binds().to_vec();
678
679 let count_sql = builder.clone().into_count_sql(driver);
681 let count_query = sqlx::query::<DB>(&count_sql);
682 let count_query = apply_binds_to_query_generic(count_query, &binds);
683 let executor_clone = executor.clone();
684 let row = count_query.fetch_one(executor_clone).await?;
685 let total: i64 = row.get(0usize);
686
687 let data_sql = builder.clone().into_paginated_sql(driver, size, offset);
689 let query = sqlx::query_as::<DB, M>(&data_sql);
690 let query = apply_binds_to_query_as_generic(query, &binds);
691 let items = query
692 .fetch_all(executor)
693 .await
694 .map_err(|e| SqlxPlusError::DatabaseError(e))?;
695
696 Ok(Page::new(items, total, page, size))
697}
698
699pub async fn hard_delete_by_id<'e, 'c: 'e, DB, M, E>(
737 executor: E,
738 id: impl for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB> + Send + Sync,
739) -> Result<()>
740where
741 DB: Database + DatabaseInfo,
742 for<'a> DB::Arguments<'a>: sqlx::IntoArguments<'a, DB>,
743 M: Model,
744 E: sqlx::Executor<'c, Database = DB> + Send,
745{
746 let escaped_table = DB::escape_identifier(M::TABLE);
747 let escaped_pk = DB::escape_identifier(M::PK);
748 let placeholder = DB::placeholder(0);
749 let sql = format!(
750 "DELETE FROM {} WHERE {} = {}",
751 escaped_table, escaped_pk, placeholder
752 );
753 sqlx::query(&sql).bind(id).execute(executor).await?;
754 Ok(())
755}
756
757pub async fn soft_delete_by_id<'e, 'c: 'e, DB, M, E>(
792 executor: E,
793 id: impl for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB> + Send + Sync,
794) -> Result<()>
795where
796 DB: Database + DatabaseInfo,
797 for<'a> DB::Arguments<'a>: sqlx::IntoArguments<'a, DB>,
798 M: Model,
799 E: sqlx::Executor<'c, Database = DB> + Send,
800{
801 let soft_delete_field = M::SOFT_DELETE_FIELD.ok_or_else(|| {
802 SqlxPlusError::DatabaseError(sqlx::Error::Configuration(
803 format!(
804 "Model {} does not have SOFT_DELETE_FIELD defined",
805 std::any::type_name::<M>()
806 )
807 .into(),
808 ))
809 })?;
810
811 let escaped_table = DB::escape_identifier(M::TABLE);
812 let escaped_pk = DB::escape_identifier(M::PK);
813 let escaped_field = DB::escape_identifier(soft_delete_field);
814 let placeholder = DB::placeholder(0);
815 let sql = format!(
816 "UPDATE {} SET {} = 1 WHERE {} = {}",
817 escaped_table, escaped_field, escaped_pk, placeholder
818 );
819 sqlx::query(&sql).bind(id).execute(executor).await?;
820 Ok(())
821}
822
823pub async fn delete_by_id<'e, 'c: 'e, DB, M, E>(
858 executor: E,
859 id: impl for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB> + Send + Sync,
860) -> Result<()>
861where
862 DB: Database + DatabaseInfo,
863 for<'a> DB::Arguments<'a>: sqlx::IntoArguments<'a, DB>,
864 M: Model,
865 E: sqlx::Executor<'c, Database = DB> + Send,
866{
867 if M::SOFT_DELETE_FIELD.is_some() {
868 soft_delete_by_id::<DB, M, E>(executor, id).await
869 } else {
870 hard_delete_by_id::<DB, M, E>(executor, id).await
871 }
872}