sqlx_plus/
lib.rs

1//! # sqlx-plus
2//!
//! Please refer to the [README](https://github.com/sifyfy/sqlx-plus) for usage and examples.
4//!
5
6use async_trait::async_trait;
7use itertools::Itertools;
8use sqlx::{database::HasArguments, Executor, IntoArguments};
9
10pub use sqlx_plus_macros::Insertable;
11
12pub trait QueryBindExt<'q, DB: sqlx::Database>: Sized {
13    fn bind<T>(self, value: T) -> Self
14    where
15        T: 'q + Send + sqlx::Encode<'q, DB> + sqlx::Type<DB>;
16
17    fn bind_with<T>(self, value: T, bind_fn: impl Fn(Self, T) -> Self) -> Self {
18        bind_fn(self, value)
19    }
20
21    fn bind_multi<T>(self, values: impl IntoIterator<Item = T>) -> Self
22    where
23        T: 'q + Send + sqlx::Encode<'q, DB> + sqlx::Type<DB>,
24    {
25        values.into_iter().fold(self, |q, v| q.bind(v))
26    }
27
28    fn bind_multi_with<T: 'q>(
29        self,
30        values: impl IntoIterator<Item = &'q T>,
31        bind_fn: impl Fn(Self, &'q T) -> Self,
32    ) -> Self {
33        values.into_iter().fold(self, |q, x| bind_fn(q, x))
34    }
35
36    fn bind_fields<T: Insertable<Database = DB>>(self, value: &'q T) -> Self {
37        value.bind_fields(self)
38    }
39
40    fn bind_multi_fields<T: Insertable<Database = DB> + 'q>(
41        self,
42        values: impl IntoIterator<Item = &'q T>,
43    ) -> Self {
44        self.bind_multi_with(values, |q, v| q.bind_fields(v))
45    }
46}
47
48impl<'q, DB: sqlx::Database> QueryBindExt<'q, DB>
49    for sqlx::query::Query<'q, DB, <DB as HasArguments<'q>>::Arguments>
50{
51    fn bind<T>(self, value: T) -> Self
52    where
53        T: 'q + Send + sqlx::Encode<'q, DB> + sqlx::Type<DB>,
54    {
55        sqlx::query::Query::bind(self, value)
56    }
57}
58
59impl<'q, DB, O> QueryBindExt<'q, DB>
60    for sqlx::query::QueryAs<'q, DB, O, <DB as HasArguments<'q>>::Arguments>
61where
62    DB: sqlx::Database,
63{
64    fn bind<T>(self, value: T) -> Self
65    where
66        T: 'q + Send + sqlx::Encode<'q, DB> + sqlx::Type<DB>,
67    {
68        sqlx::query::QueryAs::bind(self, value)
69    }
70}
71
72impl<'q, DB, O> QueryBindExt<'q, DB>
73    for sqlx::query::QueryScalar<'q, DB, O, <DB as HasArguments<'q>>::Arguments>
74where
75    DB: sqlx::Database,
76{
77    fn bind<T>(self, value: T) -> Self
78    where
79        T: 'q + Send + sqlx::Encode<'q, DB> + sqlx::Type<DB>,
80    {
81        sqlx::query::QueryScalar::bind(self, value)
82    }
83}
84
85pub trait Insertable: Sized {
86    type Database: sqlx::Database;
87
88    fn table_name() -> &'static str;
89
90    fn insert_columns() -> Vec<&'static str>;
91
92    fn bind_fields<'q, Q>(&'q self, q: Q) -> Q
93    where
94        Q: QueryBindExt<'q, Self::Database>;
95}
96
97impl<T: Insertable + Sync> Insertable for &T {
98    type Database = T::Database;
99
100    fn table_name() -> &'static str {
101        T::table_name()
102    }
103
104    fn insert_columns() -> Vec<&'static str> {
105        T::insert_columns()
106    }
107
108    fn bind_fields<'q, Q>(&'q self, q: Q) -> Q
109    where
110        Q: QueryBindExt<'q, Self::Database>,
111    {
112        (*self).bind_fields(q)
113    }
114}
115
116#[async_trait]
117pub trait Inserter<DB: sqlx::Database>: Sized {
118    async fn insert<T>(self, value: &T) -> anyhow::Result<DB::QueryResult>
119    where
120        T: Insertable<Database = DB> + Sync;
121
122    async fn bulk_insert_with_table_name_and_chunk_size<T>(
123        self,
124        table_name: &str,
125        chunk_size: usize,
126        values: &[T],
127    ) -> anyhow::Result<Vec<DB::QueryResult>>
128    where
129        T: Insertable<Database = DB> + Sync;
130
131    async fn bulk_insert<T>(self, values: &[T]) -> anyhow::Result<Vec<DB::QueryResult>>
132    where
133        T: Insertable<Database = DB> + Sync,
134    {
135        self.bulk_insert_with_table_name(T::table_name(), values)
136            .await
137    }
138
139    async fn bulk_insert_with_table_name<T>(
140        self,
141        table_name: &str,
142        values: &[T],
143    ) -> anyhow::Result<Vec<DB::QueryResult>>
144    where
145        T: Insertable<Database = DB> + Sync,
146    {
147        self.bulk_insert_with_table_name_and_chunk_size(
148            table_name,
149            30000 / T::insert_columns().len(),
150            values,
151        )
152        .await
153    }
154
155    async fn bulk_insert_with_chunk_size<T>(
156        self,
157        chunk_size: usize,
158        values: &[T],
159    ) -> anyhow::Result<Vec<DB::QueryResult>>
160    where
161        T: Insertable<Database = DB> + Sync,
162    {
163        self.bulk_insert_with_table_name_and_chunk_size(T::table_name(), chunk_size, values)
164            .await
165    }
166}
167
/// Implements [`Inserter`] for one concrete backend `$database`, in two ways.
///
/// * For `&mut E` whenever `&mut E` is an [`Executor`] for that backend: the methods
///   delegate to the module-level `insert` / `bulk_insert_with_table_name_and_chunk_size`
///   functions.
/// * For `&Pool<$database>`: the methods acquire one pooled connection and run the
///   `&mut` implementation on it.
macro_rules! impl_inserter {
    ( $database:ty ) => {
        #[async_trait]
        impl<E> Inserter<$database> for &'_ mut E
        where
            E: Send,
            for<'a> &'a mut E: Executor<'a, Database = $database>,
        {
            async fn insert<T>(
                self,
                value: &T,
            ) -> anyhow::Result<<$database as sqlx::Database>::QueryResult>
            where
                T: Insertable<Database = $database> + Sync,
            {
                // Unqualified call: resolves to the free function, not this method.
                insert(self, value).await
            }

            async fn bulk_insert_with_table_name_and_chunk_size<T>(
                self,
                table_name: &str,
                chunk_size: usize,
                values: &[T],
            ) -> anyhow::Result<Vec<<$database as sqlx::Database>::QueryResult>>
            where
                T: Insertable<Database = $database> + Sync,
            {
                // Unqualified call: resolves to the free function, not this method.
                bulk_insert_with_table_name_and_chunk_size(self, table_name, chunk_size, values)
                    .await
            }
        }

        #[async_trait]
        impl Inserter<$database> for &'_ sqlx::Pool<$database> {
            async fn insert<T>(
                self,
                value: &T,
            ) -> anyhow::Result<<$database as sqlx::Database>::QueryResult>
            where
                T: Insertable<Database = $database> + Sync,
            {
                let mut connection = self.acquire().await?;
                connection.insert(value).await
            }

            async fn bulk_insert_with_table_name_and_chunk_size<T>(
                self,
                table_name: &str,
                chunk_size: usize,
                values: &[T],
            ) -> anyhow::Result<Vec<<$database as sqlx::Database>::QueryResult>>
            where
                T: Insertable<Database = $database> + Sync,
            {
                // All chunks run on this one acquired connection.
                let mut connection = self.acquire().await?;
                connection
                    .bulk_insert_with_table_name_and_chunk_size(table_name, chunk_size, values)
                    .await
            }
        }
    };
}
234
// Implement `Inserter` for each backend enabled through its Cargo feature.
#[cfg(feature = "mssql")]
impl_inserter!(sqlx::Mssql);
#[cfg(feature = "mysql")]
impl_inserter!(sqlx::MySql);
#[cfg(feature = "postgres")]
impl_inserter!(sqlx::Postgres);
#[cfg(feature = "sqlite")]
impl_inserter!(sqlx::Sqlite);
243
244pub trait PlaceHolders: sqlx::Database {
245    /// `start_num` is for only PostgreSQL, it is ignored in other RDB.
246    #[allow(unused_variables)]
247    fn placeholders(num: usize, start_num: Option<usize>) -> String {
248        placeholders(num)
249    }
250
251    /// `start_num` is for only PostgreSQL, it is ignored in other RDB.
252    #[allow(unused_variables)]
253    fn placeholders_for_bulk_insert_values<I, T>(values: I, start_num: Option<usize>) -> String
254    where
255        I: Iterator<Item = T>,
256        T: Insertable<Database = Self>,
257    {
258        placeholders_for_bulk_insert_values(values)
259    }
260}
261
// SQLite and MySQL bind positionally with `?`, which the trait's default methods render.
#[cfg(feature = "sqlite")]
impl PlaceHolders for sqlx::Sqlite {}

#[cfg(feature = "mysql")]
impl PlaceHolders for sqlx::MySql {}

/// SQL Server binds numbered `@p1, @p2, …` parameters rather than `?`, so both methods are
/// overridden here. Numbering begins at `start_num` (default `1`).
///
/// NOTE(review): based on sqlx's MSSQL driver naming each bound argument `@pN`; confirm
/// against the pinned sqlx version.
#[cfg(feature = "mssql")]
impl PlaceHolders for sqlx::Mssql {
    fn placeholders(num: usize, start_num: Option<usize>) -> String {
        let first = start_num.unwrap_or(1);
        let end = first
            .checked_add(num)
            .expect("num > usize::MAX - start_num");
        (first..end).map(|i| format!("@p{}", i)).join(",")
    }

    fn placeholders_for_bulk_insert_values<I, T>(values: I, start_num: Option<usize>) -> String
    where
        I: Iterator<Item = T>,
        T: Insertable<Database = Self>,
    {
        let first = start_num.unwrap_or(1);
        // Every row has the same column count, so compute it once.
        let per_row = T::insert_columns().len();
        format!(
            "({})",
            values
                .enumerate()
                // Row `row` continues the numbering right after the previous row's params.
                .map(|(row, _)| Self::placeholders(per_row, Some(first + row * per_row)))
                .join("),(")
        )
    }
}
270
/// PostgreSQL uses numbered `$n` placeholders, so both methods honour `start_num`.
#[cfg(feature = "postgres")]
impl PlaceHolders for sqlx::Postgres {
    fn placeholders(num: usize, start_num: Option<usize>) -> String {
        self::placeholders_postgres(num, start_num)
    }

    fn placeholders_for_bulk_insert_values<I, T>(values: I, start_num: Option<usize>) -> String
    where
        I: Iterator<Item = T>,
        T: Insertable<Database = Self>,
    {
        self::placeholders_for_bulk_insert_values_postgres(values, start_num)
    }
}
285
/// Renders `num` positional `?` placeholders joined by commas with no spaces
/// (e.g. `?,?,?` for `num == 3`). Returns an empty string when `num` is `0`.
pub fn placeholders(num: usize) -> String {
    vec!["?"; num].join(",")
}
290
291/// Generate placeholders string like `(?, ?, ..., ?), (?, ?, ..., ?), ..., (?, ?, ..., ?)`.
292pub fn placeholders_for_bulk_insert_values<I, T>(values: I) -> String
293where
294    I: Iterator<Item = T>,
295    T: Insertable,
296{
297    format!(
298        "({})",
299        values
300            .map(|_| placeholders(T::insert_columns().len()))
301            .join("),(")
302    )
303}
304
/// Renders `num` numbered PostgreSQL placeholders `$s,$s+1,…` joined by commas with no
/// spaces, where `s` is `start_num` (default `1`). Returns an empty string when `num` is `0`.
///
/// # Panics
///
/// Panics if `start_num + num` overflows `usize`.
pub fn placeholders_postgres(num: usize, start_num: Option<usize>) -> String {
    let first = start_num.unwrap_or(1);
    let end = first
        .checked_add(num)
        .expect("num > usize::MAX - start_num");

    (first..end)
        .map(|index| format!("${}", index))
        .collect::<Vec<_>>()
        .join(",")
}
318
319/// Generate placeholders string like `($1, $2, ..., $n), ($o, $p, ..., $q), ..., ($r, $s, ..., $u)`.
320pub fn placeholders_for_bulk_insert_values_postgres<'a, I, T>(
321    values: I,
322    start_num: Option<usize>,
323) -> String
324where
325    I: Iterator<Item = T>,
326    T: Insertable,
327{
328    let start_num = start_num.unwrap_or(1);
329
330    format!(
331        "({})",
332        values
333            .enumerate()
334            .map(|(i, _)| {
335                let num_of_fields = T::insert_columns().len();
336                let start_num = start_num + i * num_of_fields;
337                placeholders_postgres(num_of_fields, Some(start_num))
338            })
339            .join("),(")
340    )
341}
342
343async fn insert<T, E, DB>(executor: &mut E, value: &T) -> anyhow::Result<DB::QueryResult>
344where
345    DB: sqlx::Database + PlaceHolders,
346    T: Insertable<Database = DB> + Sync,
347    for<'e> &'e mut E: Executor<'e, Database = DB>,
348    for<'q> <DB as HasArguments<'q>>::Arguments: IntoArguments<'q, DB>,
349{
350    let sql = format!(
351        r#"
352            INSERT INTO {table_name} ({columns}) VALUES ({placeholders})
353        "#,
354        table_name = T::table_name(),
355        columns = T::insert_columns().join(","),
356        placeholders = DB::placeholders(T::insert_columns().len(), None),
357    );
358
359    sqlx::query(&sql)
360        .bind_fields(value)
361        .execute(executor)
362        .await
363        .map_err(From::from)
364}
365
366async fn bulk_insert_with_table_name_and_chunk_size<T, E, DB>(
367    executor: &mut E,
368    table_name: &str,
369    chunk_size: usize,
370    values: &[T],
371) -> anyhow::Result<Vec<DB::QueryResult>>
372where
373    DB: sqlx::Database + PlaceHolders,
374    T: Insertable<Database = DB> + Sync,
375    for<'e> &'e mut E: Executor<'e, Database = DB>,
376    for<'q> <DB as HasArguments<'q>>::Arguments: IntoArguments<'q, DB>,
377{
378    let mut results = Vec::with_capacity(values.len() / chunk_size);
379
380    for chunk in values.chunks(chunk_size) {
381        let sql = format!(
382            r#"
383                    INSERT INTO {table_name} ({columns}) VALUES {placeholders}
384            "#,
385            columns = T::insert_columns().join(","),
386            placeholders = DB::placeholders_for_bulk_insert_values(chunk.iter(), None),
387        );
388        let result = sqlx::query(&sql)
389            .bind_multi_fields(chunk)
390            .execute(&mut *executor)
391            .await?;
392
393        results.push(result);
394    }
395
396    Ok(results)
397}