1use async_trait::async_trait;
7use itertools::Itertools;
8use sqlx::{database::HasArguments, Executor, IntoArguments};
9
10pub use sqlx_plus_macros::Insertable;
11
12pub trait QueryBindExt<'q, DB: sqlx::Database>: Sized {
13 fn bind<T>(self, value: T) -> Self
14 where
15 T: 'q + Send + sqlx::Encode<'q, DB> + sqlx::Type<DB>;
16
17 fn bind_with<T>(self, value: T, bind_fn: impl Fn(Self, T) -> Self) -> Self {
18 bind_fn(self, value)
19 }
20
21 fn bind_multi<T>(self, values: impl IntoIterator<Item = T>) -> Self
22 where
23 T: 'q + Send + sqlx::Encode<'q, DB> + sqlx::Type<DB>,
24 {
25 values.into_iter().fold(self, |q, v| q.bind(v))
26 }
27
28 fn bind_multi_with<T: 'q>(
29 self,
30 values: impl IntoIterator<Item = &'q T>,
31 bind_fn: impl Fn(Self, &'q T) -> Self,
32 ) -> Self {
33 values.into_iter().fold(self, |q, x| bind_fn(q, x))
34 }
35
36 fn bind_fields<T: Insertable<Database = DB>>(self, value: &'q T) -> Self {
37 value.bind_fields(self)
38 }
39
40 fn bind_multi_fields<T: Insertable<Database = DB> + 'q>(
41 self,
42 values: impl IntoIterator<Item = &'q T>,
43 ) -> Self {
44 self.bind_multi_with(values, |q, v| q.bind_fields(v))
45 }
46}
47
/// [`QueryBindExt`] for plain `sqlx::query(..)` builders.
impl<'q, DB: sqlx::Database> QueryBindExt<'q, DB>
    for sqlx::query::Query<'q, DB, <DB as HasArguments<'q>>::Arguments>
{
    fn bind<T>(self, value: T) -> Self
    where
        T: 'q + Send + sqlx::Encode<'q, DB> + sqlx::Type<DB>,
    {
        // Delegates to sqlx's own `Query::bind` through its full path. This relies
        // on sqlx providing an inherent `bind` (inherent items win path resolution);
        // otherwise this would recurse into `QueryBindExt::bind`.
        sqlx::query::Query::bind(self, value)
    }
}
58
/// [`QueryBindExt`] for `sqlx::query_as(..)` builders (`O` is the row type).
impl<'q, DB, O> QueryBindExt<'q, DB>
    for sqlx::query::QueryAs<'q, DB, O, <DB as HasArguments<'q>>::Arguments>
where
    DB: sqlx::Database,
{
    fn bind<T>(self, value: T) -> Self
    where
        T: 'q + Send + sqlx::Encode<'q, DB> + sqlx::Type<DB>,
    {
        // Delegates to sqlx's own `QueryAs::bind` through its full path. This relies
        // on sqlx providing an inherent `bind`; otherwise it would recurse into
        // `QueryBindExt::bind`.
        sqlx::query::QueryAs::bind(self, value)
    }
}
71
/// [`QueryBindExt`] for `sqlx::query_scalar(..)` builders (`O` is the scalar type).
impl<'q, DB, O> QueryBindExt<'q, DB>
    for sqlx::query::QueryScalar<'q, DB, O, <DB as HasArguments<'q>>::Arguments>
where
    DB: sqlx::Database,
{
    fn bind<T>(self, value: T) -> Self
    where
        T: 'q + Send + sqlx::Encode<'q, DB> + sqlx::Type<DB>,
    {
        // Delegates to sqlx's own `QueryScalar::bind` through its full path. This
        // relies on sqlx providing an inherent `bind`; otherwise it would recurse
        // into `QueryBindExt::bind`.
        sqlx::query::QueryScalar::bind(self, value)
    }
}
84
85pub trait Insertable: Sized {
86 type Database: sqlx::Database;
87
88 fn table_name() -> &'static str;
89
90 fn insert_columns() -> Vec<&'static str>;
91
92 fn bind_fields<'q, Q>(&'q self, q: Q) -> Q
93 where
94 Q: QueryBindExt<'q, Self::Database>;
95}
96
97impl<T: Insertable + Sync> Insertable for &T {
98 type Database = T::Database;
99
100 fn table_name() -> &'static str {
101 T::table_name()
102 }
103
104 fn insert_columns() -> Vec<&'static str> {
105 T::insert_columns()
106 }
107
108 fn bind_fields<'q, Q>(&'q self, q: Q) -> Q
109 where
110 Q: QueryBindExt<'q, Self::Database>,
111 {
112 (*self).bind_fields(q)
113 }
114}
115
116#[async_trait]
117pub trait Inserter<DB: sqlx::Database>: Sized {
118 async fn insert<T>(self, value: &T) -> anyhow::Result<DB::QueryResult>
119 where
120 T: Insertable<Database = DB> + Sync;
121
122 async fn bulk_insert_with_table_name_and_chunk_size<T>(
123 self,
124 table_name: &str,
125 chunk_size: usize,
126 values: &[T],
127 ) -> anyhow::Result<Vec<DB::QueryResult>>
128 where
129 T: Insertable<Database = DB> + Sync;
130
131 async fn bulk_insert<T>(self, values: &[T]) -> anyhow::Result<Vec<DB::QueryResult>>
132 where
133 T: Insertable<Database = DB> + Sync,
134 {
135 self.bulk_insert_with_table_name(T::table_name(), values)
136 .await
137 }
138
139 async fn bulk_insert_with_table_name<T>(
140 self,
141 table_name: &str,
142 values: &[T],
143 ) -> anyhow::Result<Vec<DB::QueryResult>>
144 where
145 T: Insertable<Database = DB> + Sync,
146 {
147 self.bulk_insert_with_table_name_and_chunk_size(
148 table_name,
149 30000 / T::insert_columns().len(),
150 values,
151 )
152 .await
153 }
154
155 async fn bulk_insert_with_chunk_size<T>(
156 self,
157 chunk_size: usize,
158 values: &[T],
159 ) -> anyhow::Result<Vec<DB::QueryResult>>
160 where
161 T: Insertable<Database = DB> + Sync,
162 {
163 self.bulk_insert_with_table_name_and_chunk_size(T::table_name(), chunk_size, values)
164 .await
165 }
166}
167
/// Implements [`Inserter`] for the sqlx backend `$db` on two handle types:
///
/// - `&mut E` whenever `&mut E` is an sqlx `Executor` for `$db`. These forward
///   to this module's free functions `insert` and
///   `bulk_insert_with_table_name_and_chunk_size`.
/// - `&Pool<$db>`. These acquire one pooled connection and run the operation on it.
macro_rules! impl_inserter {
    ( $db:ty ) => {
        #[async_trait]
        impl<E> Inserter<$db> for &'_ mut E
        where
            E: Send,
            for<'a> &'a mut E: Executor<'a, Database = $db>,
        {
            async fn insert<T>(
                self,
                value: &T,
            ) -> anyhow::Result<<$db as sqlx::Database>::QueryResult>
            where
                T: Insertable<Database = $db> + Sync,
            {
                // A bare `insert(..)` names the module-level function, not this method.
                let result = insert(self, value).await?;
                Ok(result)
            }

            async fn bulk_insert_with_table_name_and_chunk_size<T>(
                self,
                table_name: &str,
                chunk_size: usize,
                values: &[T],
            ) -> anyhow::Result<Vec<<$db as sqlx::Database>::QueryResult>>
            where
                T: Insertable<Database = $db> + Sync,
            {
                // Likewise resolves to the module-level free function.
                let results =
                    bulk_insert_with_table_name_and_chunk_size(self, table_name, chunk_size, values)
                        .await?;
                Ok(results)
            }
        }

        #[async_trait]
        impl Inserter<$db> for &'_ sqlx::Pool<$db> {
            async fn insert<T>(
                self,
                value: &T,
            ) -> anyhow::Result<<$db as sqlx::Database>::QueryResult>
            where
                T: Insertable<Database = $db> + Sync,
            {
                let mut conn = self.acquire().await?;
                let result = conn.insert(value).await?;
                Ok(result)
            }

            async fn bulk_insert_with_table_name_and_chunk_size<T>(
                self,
                table_name: &str,
                chunk_size: usize,
                values: &[T],
            ) -> anyhow::Result<Vec<<$db as sqlx::Database>::QueryResult>>
            where
                T: Insertable<Database = $db> + Sync,
            {
                // Every chunk is sent over this one acquired connection.
                let mut conn = self.acquire().await?;
                let results = conn
                    .bulk_insert_with_table_name_and_chunk_size(table_name, chunk_size, values)
                    .await?;
                Ok(results)
            }
        }
    };
}
234
// One `Inserter` implementation pair per sqlx backend enabled via cargo features.
#[cfg(feature = "mssql")]
impl_inserter! { sqlx::Mssql }
#[cfg(feature = "mysql")]
impl_inserter! { sqlx::MySql }
#[cfg(feature = "postgres")]
impl_inserter! { sqlx::Postgres }
#[cfg(feature = "sqlite")]
impl_inserter! { sqlx::Sqlite }
243
244pub trait PlaceHolders: sqlx::Database {
245 #[allow(unused_variables)]
247 fn placeholders(num: usize, start_num: Option<usize>) -> String {
248 placeholders(num)
249 }
250
251 #[allow(unused_variables)]
253 fn placeholders_for_bulk_insert_values<I, T>(values: I, start_num: Option<usize>) -> String
254 where
255 I: Iterator<Item = T>,
256 T: Insertable<Database = Self>,
257 {
258 placeholders_for_bulk_insert_values(values)
259 }
260}
261
/// SQLite: uses the default unnumbered `?` placeholders.
#[cfg(feature = "sqlite")]
impl PlaceHolders for sqlx::Sqlite {}

/// MySQL: uses the default unnumbered `?` placeholders.
#[cfg(feature = "mysql")]
impl PlaceHolders for sqlx::MySql {}

/// SQL Server: sqlx's MSSQL driver takes numbered `@p1, @p2, ...` parameters
/// rather than `?`, so both methods are overridden with numbered output
/// (starting at `start_num`, default `1`).
#[cfg(feature = "mssql")]
impl PlaceHolders for sqlx::Mssql {
    fn placeholders(num: usize, start_num: Option<usize>) -> String {
        let start_num = start_num.unwrap_or(1);
        let end = start_num
            .checked_add(num)
            .expect("num > usize::MAX - start_num");

        (start_num..end).map(|i| format!("@p{}", i)).join(",")
    }

    fn placeholders_for_bulk_insert_values<I, T>(values: I, start_num: Option<usize>) -> String
    where
        I: Iterator<Item = T>,
        T: Insertable<Database = Self>,
    {
        let start_num = start_num.unwrap_or(1);
        // Every row has the same number of columns.
        let num_of_fields = T::insert_columns().len();

        format!(
            "({})",
            values
                .enumerate()
                .map(|(i, _)| {
                    // Numbering continues across rows: row `i` starts after `i` full rows.
                    let row_start = start_num + i * num_of_fields;
                    Self::placeholders(num_of_fields, Some(row_start))
                })
                .join("),(")
        )
    }
}

/// PostgreSQL: numbered `$1, $2, ...` placeholders.
#[cfg(feature = "postgres")]
impl PlaceHolders for sqlx::Postgres {
    fn placeholders(num: usize, start_num: Option<usize>) -> String {
        placeholders_postgres(num, start_num)
    }

    fn placeholders_for_bulk_insert_values<I, T>(values: I, start_num: Option<usize>) -> String
    where
        I: Iterator<Item = T>,
        T: Insertable<Database = Self>,
    {
        placeholders_for_bulk_insert_values_postgres(values, start_num)
    }
}
285
/// Builds `num` comma-separated `?` placeholders (`?,?,...`).
///
/// This is the syntax used by the default `PlaceHolders` methods. Returns an
/// empty string when `num == 0`.
pub fn placeholders(num: usize) -> String {
    vec!["?"; num].join(",")
}
290
291pub fn placeholders_for_bulk_insert_values<I, T>(values: I) -> String
293where
294 I: Iterator<Item = T>,
295 T: Insertable,
296{
297 format!(
298 "({})",
299 values
300 .map(|_| placeholders(T::insert_columns().len()))
301 .join("),(")
302 )
303}
304
/// Builds `num` comma-separated PostgreSQL placeholders (`$n,$n+1,...`).
///
/// Numbering starts at `start_num`, which defaults to `1`. Returns an empty
/// string when `num == 0`.
///
/// # Panics
///
/// Panics if `start_num + num` overflows `usize`.
pub fn placeholders_postgres(num: usize, start_num: Option<usize>) -> String {
    let start_num = start_num.unwrap_or(1);
    // Exclusive upper bound of the placeholder numbers; overflow is a caller bug.
    let end = start_num
        .checked_add(num)
        .expect("num > usize::MAX - start_num");

    (start_num..end)
        .map(|i| format!("${}", i))
        .collect::<Vec<_>>()
        .join(",")
}
318
319pub fn placeholders_for_bulk_insert_values_postgres<'a, I, T>(
321 values: I,
322 start_num: Option<usize>,
323) -> String
324where
325 I: Iterator<Item = T>,
326 T: Insertable,
327{
328 let start_num = start_num.unwrap_or(1);
329
330 format!(
331 "({})",
332 values
333 .enumerate()
334 .map(|(i, _)| {
335 let num_of_fields = T::insert_columns().len();
336 let start_num = start_num + i * num_of_fields;
337 placeholders_postgres(num_of_fields, Some(start_num))
338 })
339 .join("),(")
340 )
341}
342
343async fn insert<T, E, DB>(executor: &mut E, value: &T) -> anyhow::Result<DB::QueryResult>
344where
345 DB: sqlx::Database + PlaceHolders,
346 T: Insertable<Database = DB> + Sync,
347 for<'e> &'e mut E: Executor<'e, Database = DB>,
348 for<'q> <DB as HasArguments<'q>>::Arguments: IntoArguments<'q, DB>,
349{
350 let sql = format!(
351 r#"
352 INSERT INTO {table_name} ({columns}) VALUES ({placeholders})
353 "#,
354 table_name = T::table_name(),
355 columns = T::insert_columns().join(","),
356 placeholders = DB::placeholders(T::insert_columns().len(), None),
357 );
358
359 sqlx::query(&sql)
360 .bind_fields(value)
361 .execute(executor)
362 .await
363 .map_err(From::from)
364}
365
366async fn bulk_insert_with_table_name_and_chunk_size<T, E, DB>(
367 executor: &mut E,
368 table_name: &str,
369 chunk_size: usize,
370 values: &[T],
371) -> anyhow::Result<Vec<DB::QueryResult>>
372where
373 DB: sqlx::Database + PlaceHolders,
374 T: Insertable<Database = DB> + Sync,
375 for<'e> &'e mut E: Executor<'e, Database = DB>,
376 for<'q> <DB as HasArguments<'q>>::Arguments: IntoArguments<'q, DB>,
377{
378 let mut results = Vec::with_capacity(values.len() / chunk_size);
379
380 for chunk in values.chunks(chunk_size) {
381 let sql = format!(
382 r#"
383 INSERT INTO {table_name} ({columns}) VALUES {placeholders}
384 "#,
385 columns = T::insert_columns().join(","),
386 placeholders = DB::placeholders_for_bulk_insert_values(chunk.iter(), None),
387 );
388 let result = sqlx::query(&sql)
389 .bind_multi_fields(chunk)
390 .execute(&mut *executor)
391 .await?;
392
393 results.push(result);
394 }
395
396 Ok(results)
397}