sqlx_utils/utils/batch.rs

1use futures::future::try_join_all;
2use futures::FutureExt;
3use std::future::Future;
4use std::ops::{Deref, DerefMut};
5use tracing::instrument;
6
/// Batch size used by [`BatchOperator`] when no explicit `N` is supplied.
pub const DEFAULT_BATCH_SIZE: usize = 1 << 8;

/// Buffers items of type `T` and hands them off in batches of at most `N` items.
///
/// The batching helpers compare the buffer length with `N` right after each
/// push. So with `N == 0` that check never fires, and every item ends up in a
/// single final batch.
#[derive(Debug)]
pub struct BatchOperator<T, const N: usize = DEFAULT_BATCH_SIZE>(Vec<T>);
11
12impl<T, const N: usize> Deref for BatchOperator<T, N> {
13    type Target = Vec<T>;
14    fn deref(&self) -> &Self::Target {
15        &self.0
16    }
17}
18
19impl<T, const N: usize> DerefMut for BatchOperator<T, N> {
20    fn deref_mut(&mut self) -> &mut Self::Target {
21        &mut self.0
22    }
23}
24
25impl<T, const N: usize> Default for BatchOperator<T, N> {
26    fn default() -> Self {
27        Self::new()
28    }
29}
30
31impl<T, const N: usize> BatchOperator<T, N> {
32    pub fn new() -> Self {
33        Self(Vec::with_capacity(N))
34    }
35
36    async fn execute_query_internal<'a>(
37        items: &'a mut Vec<T>,
38        pool: & crate::types::Pool,
39        query: fn(&T) -> crate::types::Query,
40    ) -> crate::Result<()> {
41        if items.is_empty() {
42            return Ok(());
43        }
44
45        let mut tx = pool.begin().await?;
46
47        for item in items.drain(..) {
48            query(&item).execute(&mut *tx).await?;
49        }
50
51        tx.commit().await?;
52
53        Ok(())
54    }
55
56    #[instrument(skip_all, level = "debug")]
57    pub async fn execute_query(
58        iter: impl IntoIterator<Item = T>,
59        pool: &crate::types::Pool,
60        query: fn(&T) -> crate::types::Query,
61    ) -> crate::Result<()> {
62        let mut buf = Self::new();
63
64        for item in iter {
65            buf.push(item);
66
67            if buf.len() == N {
68                Self::execute_query_internal(&mut buf.0, pool, query).await?;
69            }
70        }
71
72        Self::execute_query_internal(&mut buf.0, pool, query).await?;
73
74        Ok(())
75    }
76
77    #[instrument(skip_all, level = "debug")]
78    pub async fn execute_batch<F, Fut, E>(
79        iter: impl IntoIterator<Item = T>,
80        worker: F,
81    ) -> Result<(), E>
82    where
83        F: Fn(Vec<T>) -> Fut,
84        Fut: Future<Output = Result<(), E>>,
85    {
86        let mut buf = Self::new();
87        let mut futures = Vec::new();
88
89        for item in iter {
90            buf.push(item);
91            if buf.len() == N {
92                futures.push(worker(buf.drain(..).collect()));
93            }
94        }
95
96        if !buf.is_empty() {
97            futures.push(worker(buf.drain(..).collect()));
98        }
99
100        try_join_all(futures).await?;
101        Ok(())
102    }
103
104    #[instrument(skip_all, level = "debug")]
105    /// # NOTE: This only works in a non Send context
106    pub async fn partition_execute<F1, F2, Fut1, Fut2, P, E>(
107        iter: impl IntoIterator<Item = T>,
108        predicate: P,
109        worker1: F1,
110        worker2: F2,
111    ) -> Result<(), E>
112    where
113        P: Fn(&T) -> bool,
114        F1: Fn(Vec<T>) -> Fut1,
115        F2: Fn(Vec<T>) -> Fut2,
116        Fut1: Future<Output = Result<(), E>>,
117        Fut2: Future<Output = Result<(), E>>,
118    {
119        let mut buf1 = Self::new();
120        let mut buf2 = Self::new();
121        let mut futures = Vec::new();
122
123        for item in iter {
124            if predicate(&item) {
125                buf1.push(item);
126                if buf1.len() == N {
127                    futures.push(worker1(buf1.drain(..).collect()).boxed_local());
128                }
129            } else {
130                buf2.push(item);
131                if buf2.len() == N {
132                    futures.push(worker2(buf2.drain(..).collect()).boxed_local());
133                }
134            }
135        }
136
137        if !buf1.is_empty() {
138            futures.push(worker1(buf1.drain(..).collect()).boxed_local());
139        }
140        if !buf2.is_empty() {
141            futures.push(worker2(buf2.drain(..).collect()).boxed_local());
142        }
143
144        try_join_all(futures).await?;
145
146        Ok(())
147    }
148}