sqlx_utils/utils/
batch.rs1use futures::future::try_join_all;
2use futures::FutureExt;
3use std::future::Future;
4use std::ops::{Deref, DerefMut};
5use tracing::instrument;
6
/// Batch size used by [`BatchOperator`] when no explicit `N` is supplied.
pub const DEFAULT_BATCH_SIZE: usize = 256;

/// A staging buffer that collects items of type `T` so they can be processed in batches of `N`.
///
/// The wrapped `Vec<T>` is exposed through `Deref`/`DerefMut`; the associated functions
/// (`execute_query`, `execute_batch`, `partition_execute`) use it to accumulate each batch.
#[derive(Debug)]
pub struct BatchOperator<T, const N: usize = DEFAULT_BATCH_SIZE>(Vec<T>);

12impl<T, const N: usize> Deref for BatchOperator<T, N> {
13 type Target = Vec<T>;
14 fn deref(&self) -> &Self::Target {
15 &self.0
16 }
17}
18
19impl<T, const N: usize> DerefMut for BatchOperator<T, N> {
20 fn deref_mut(&mut self) -> &mut Self::Target {
21 &mut self.0
22 }
23}
24
25impl<T, const N: usize> Default for BatchOperator<T, N> {
26 fn default() -> Self {
27 Self::new()
28 }
29}
30
31impl<T, const N: usize> BatchOperator<T, N> {
32 pub fn new() -> Self {
33 Self(Vec::with_capacity(N))
34 }
35
36 async fn execute_query_internal<'a>(
37 items: &'a mut Vec<T>,
38 pool: & crate::types::Pool,
39 query: fn(&T) -> crate::types::Query,
40 ) -> crate::Result<()> {
41 if items.is_empty() {
42 return Ok(());
43 }
44
45 let mut tx = pool.begin().await?;
46
47 for item in items.drain(..) {
48 query(&item).execute(&mut *tx).await?;
49 }
50
51 tx.commit().await?;
52
53 Ok(())
54 }
55
56 #[instrument(skip_all, level = "debug")]
57 pub async fn execute_query(
58 iter: impl IntoIterator<Item = T>,
59 pool: &crate::types::Pool,
60 query: fn(&T) -> crate::types::Query,
61 ) -> crate::Result<()> {
62 let mut buf = Self::new();
63
64 for item in iter {
65 buf.push(item);
66
67 if buf.len() == N {
68 Self::execute_query_internal(&mut buf.0, pool, query).await?;
69 }
70 }
71
72 Self::execute_query_internal(&mut buf.0, pool, query).await?;
73
74 Ok(())
75 }
76
77 #[instrument(skip_all, level = "debug")]
78 pub async fn execute_batch<F, Fut, E>(
79 iter: impl IntoIterator<Item = T>,
80 worker: F,
81 ) -> Result<(), E>
82 where
83 F: Fn(Vec<T>) -> Fut,
84 Fut: Future<Output = Result<(), E>>,
85 {
86 let mut buf = Self::new();
87 let mut futures = Vec::new();
88
89 for item in iter {
90 buf.push(item);
91 if buf.len() == N {
92 futures.push(worker(buf.drain(..).collect()));
93 }
94 }
95
96 if !buf.is_empty() {
97 futures.push(worker(buf.drain(..).collect()));
98 }
99
100 try_join_all(futures).await?;
101 Ok(())
102 }
103
104 #[instrument(skip_all, level = "debug")]
105 pub async fn partition_execute<F1, F2, Fut1, Fut2, P, E>(
107 iter: impl IntoIterator<Item = T>,
108 predicate: P,
109 worker1: F1,
110 worker2: F2,
111 ) -> Result<(), E>
112 where
113 P: Fn(&T) -> bool,
114 F1: Fn(Vec<T>) -> Fut1,
115 F2: Fn(Vec<T>) -> Fut2,
116 Fut1: Future<Output = Result<(), E>>,
117 Fut2: Future<Output = Result<(), E>>,
118 {
119 let mut buf1 = Self::new();
120 let mut buf2 = Self::new();
121 let mut futures = Vec::new();
122
123 for item in iter {
124 if predicate(&item) {
125 buf1.push(item);
126 if buf1.len() == N {
127 futures.push(worker1(buf1.drain(..).collect()).boxed_local());
128 }
129 } else {
130 buf2.push(item);
131 if buf2.len() == N {
132 futures.push(worker2(buf2.drain(..).collect()).boxed_local());
133 }
134 }
135 }
136
137 if !buf1.is_empty() {
138 futures.push(worker1(buf1.drain(..).collect()).boxed_local());
139 }
140 if !buf2.is_empty() {
141 futures.push(worker2(buf2.drain(..).collect()).boxed_local());
142 }
143
144 try_join_all(futures).await?;
145
146 Ok(())
147 }
148}