// spring_batch_rs/item/rdbc/rdbc_reader.rs

1use std::cell::{Cell, RefCell};
2
3use serde::de::DeserializeOwned;
4use sqlx::{any::AnyRow, Any, Pool, QueryBuilder};
5
6use crate::core::item::{ItemReader, ItemReaderResult};
7
8/// Trait for mapping a database row to a specific type.
9pub trait RdbcRowMapper<T> {
10    /// Maps a database row to the specified type.
11    fn map_row(&self, row: &AnyRow) -> T;
12}
13
14/// A reader for reading items from a relational database using SQLx.
15pub struct RdbcItemReader<'a, T> {
16    pool: &'a Pool<Any>,
17    query: &'a str,
18    page_size: Option<i32>,
19    offset: Cell<i32>,
20    row_mapper: &'a dyn RdbcRowMapper<T>,
21    buffer: RefCell<Vec<T>>,
22}
23
24impl<'a, T> RdbcItemReader<'a, T> {
25    /// Creates a new `RdbcItemReader`.
26    ///
27    /// # Arguments
28    ///
29    /// * `pool` - The database connection pool.
30    /// * `query` - The SQL query to execute.
31    /// * `page_size` - The number of items to read per page.
32    /// * `row_mapper` - The row mapper for mapping database rows to items.
33    ///
34    /// # Returns
35    ///
36    /// A new `RdbcItemReader` instance.
37    pub fn new(
38        pool: &'a Pool<Any>,
39        query: &'a str,
40        page_size: Option<i32>,
41        row_mapper: &'a dyn RdbcRowMapper<T>,
42    ) -> Self {
43        let buffer = if let Some(page_size) = page_size {
44            let buffer_size = page_size.try_into().unwrap_or(1);
45            Vec::with_capacity(buffer_size)
46        } else {
47            Vec::new()
48        };
49
50        Self {
51            pool,
52            query,
53            page_size,
54            offset: Cell::new(0),
55            row_mapper,
56            buffer: RefCell::new(buffer),
57        }
58    }
59
60    /// Reads a page of items from the database.
61    fn read_page(&self) {
62        let mut query_builder = QueryBuilder::new(self.query);
63
64        if self.page_size.is_some() {
65            query_builder.push(format!(
66                " LIMIT {} OFFSET {}",
67                self.page_size.unwrap(),
68                self.offset.get()
69            ));
70        }
71
72        let query = query_builder.build();
73
74        let rows = tokio::task::block_in_place(|| {
75            tokio::runtime::Runtime::new()
76                .unwrap()
77                .block_on(async { query.fetch_all(self.pool).await.unwrap() })
78        });
79
80        self.buffer.borrow_mut().clear();
81
82        rows.iter().for_each(|x| {
83            let item = self.row_mapper.map_row(x);
84            self.buffer.borrow_mut().push(item);
85        });
86    }
87}
88
89impl<'a, T: DeserializeOwned + Clone> ItemReader<T> for RdbcItemReader<'a, T> {
90    /// Reads the next item from the reader.
91    ///
92    /// # Returns
93    ///
94    /// The next item, or `None` if there are no more items.
95    fn read(&self) -> ItemReaderResult<T> {
96        let index = if let Some(page_size) = self.page_size {
97            self.offset.get() % page_size
98        } else {
99            self.offset.get()
100        };
101
102        if index == 0 {
103            self.read_page();
104        }
105
106        let buffer = self.buffer.borrow();
107
108        let result = buffer.get(index as usize);
109
110        self.offset.set(self.offset.get() + 1);
111
112        Ok(result.cloned())
113    }
114}
115
116/// Builder for creating an `RdbcItemReader`.
117#[derive(Default)]
118pub struct RdbcItemReaderBuilder<'a, T> {
119    pool: Option<&'a Pool<Any>>,
120    query: Option<&'a str>,
121    page_size: Option<i32>,
122    row_mapper: Option<&'a dyn RdbcRowMapper<T>>,
123}
124
125impl<'a, T> RdbcItemReaderBuilder<'a, T> {
126    /// Creates a new `RdbcItemReaderBuilder`.
127    pub fn new() -> Self {
128        Self {
129            pool: None,
130            query: None,
131            page_size: None,
132            row_mapper: None,
133        }
134    }
135
136    /// Sets the page size for the reader.
137    ///
138    /// # Arguments
139    ///
140    /// * `page_size` - The number of items to read per page.
141    ///
142    /// # Returns
143    ///
144    /// The updated `RdbcItemReaderBuilder` instance.
145    pub fn page_size(mut self, page_size: i32) -> Self {
146        self.page_size = Some(page_size);
147        self
148    }
149
150    /// Sets the SQL query for the reader.
151    ///
152    /// # Arguments
153    ///
154    /// * `query` - The SQL query to execute.
155    ///
156    /// # Returns
157    ///
158    /// The updated `RdbcItemReaderBuilder` instance.
159    pub fn query(mut self, query: &'a str) -> Self {
160        self.query = Some(query);
161        self
162    }
163
164    /// Sets the database connection pool for the reader.
165    ///
166    /// # Arguments
167    ///
168    /// * `pool` - The database connection pool.
169    ///
170    /// # Returns
171    ///
172    /// The updated `RdbcItemReaderBuilder` instance.
173    pub fn pool(mut self, pool: &'a Pool<Any>) -> Self {
174        self.pool = Some(pool);
175        self
176    }
177
178    /// Sets the row mapper for the reader.
179    ///
180    /// # Arguments
181    ///
182    /// * `row_mapper` - The row mapper for mapping database rows to items.
183    ///
184    /// # Returns
185    ///
186    /// The updated `RdbcItemReaderBuilder` instance.
187    pub fn row_mapper(mut self, row_mapper: &'a dyn RdbcRowMapper<T>) -> Self {
188        self.row_mapper = Some(row_mapper);
189        self
190    }
191
192    /// Builds the `RdbcItemReader` instance.
193    ///
194    /// # Returns
195    ///
196    /// The built `RdbcItemReader` instance.
197    pub fn build(self) -> RdbcItemReader<'a, T> {
198        RdbcItemReader::new(
199            self.pool.unwrap(),
200            self.query.unwrap(),
201            self.page_size,
202            self.row_mapper.unwrap(),
203        )
204    }
205}