spring_batch_rs/item/rdbc/
mysql_reader.rs1use std::cell::{Cell, RefCell};
2
3use sqlx::{Execute, FromRow, MySql, Pool, QueryBuilder, mysql::MySqlRow};
4
5use super::reader_common::{calculate_page_index, should_load_page};
6use crate::BatchError;
7use crate::core::item::{ItemReader, ItemReaderResult};
8
9pub struct MySqlRdbcItemReader<'a, I>
24where
25 for<'r> I: FromRow<'r, MySqlRow> + Send + Unpin + Clone,
26{
27 pub(crate) pool: Pool<MySql>,
28 pub(crate) query: &'a str,
29 pub(crate) page_size: Option<i32>,
30 pub(crate) offset: Cell<i32>,
31 pub(crate) buffer: RefCell<Vec<I>>,
32 pub(crate) keyset_column: Option<String>,
33 #[allow(clippy::type_complexity)]
34 pub(crate) keyset_key: Option<Box<dyn Fn(&I) -> String>>,
35 pub(crate) last_cursor: RefCell<Option<String>>,
36}
37
38impl<'a, I> MySqlRdbcItemReader<'a, I>
39where
40 for<'r> I: FromRow<'r, MySqlRow> + Send + Unpin + Clone,
41{
42 #[allow(clippy::type_complexity)]
47 pub fn new(
48 pool: Pool<MySql>,
49 query: &'a str,
50 page_size: Option<i32>,
51 keyset_column: Option<String>,
52 keyset_key: Option<Box<dyn Fn(&I) -> String>>,
53 ) -> Self {
54 Self {
55 pool,
56 query,
57 page_size,
58 offset: Cell::new(0),
59 buffer: RefCell::new(vec![]),
60 keyset_column,
61 keyset_key,
62 last_cursor: RefCell::new(None),
63 }
64 }
65
66 fn read_page(&self) -> Result<(), BatchError> {
77 let mut query_builder = QueryBuilder::<MySql>::new(self.query);
78
79 if let Some(page_size) = self.page_size {
80 if let Some(ref col) = self.keyset_column {
81 let last = self.last_cursor.borrow();
82 if let Some(ref cursor_val) = *last {
83 let escaped = cursor_val.replace('\'', "''");
84 query_builder.push(format!(" WHERE {} > '{}'", col, escaped));
85 }
86 query_builder.push(format!(" ORDER BY {} LIMIT {}", col, page_size));
87 } else {
88 query_builder.push(format!(" LIMIT {} OFFSET {}", page_size, self.offset.get()));
89 }
90 }
91
92 let query = query_builder.build();
93
94 let items = tokio::task::block_in_place(|| {
95 tokio::runtime::Handle::current().block_on(async {
96 sqlx::query_as::<_, I>(query.sql())
97 .fetch_all(&self.pool)
98 .await
99 .map_err(|e| BatchError::ItemReader(e.to_string()))
100 })
101 })?;
102
103 self.buffer.borrow_mut().clear();
104 self.buffer.borrow_mut().extend(items);
105 Ok(())
106 }
107}
108
#[cfg(test)]
mod tests {
    use super::*;
    use sqlx::MySqlPool;

    /// Placeholder row type whose `from_row` ignores the row entirely.
    #[derive(Clone)]
    struct Dummy;

    impl<'r> sqlx::FromRow<'r, sqlx::mysql::MySqlRow> for Dummy {
        fn from_row(_row: &'r sqlx::mysql::MySqlRow) -> Result<Self, sqlx::Error> {
            Ok(Self)
        }
    }

    /// Boxed keyset extractor, the same type `MySqlRdbcItemReader::new` accepts.
    type DummyKeyFn = Box<dyn Fn(&Dummy) -> String>;

    /// Builds a reader for `SELECT 1` with a page size of 10 over a lazily created pool.
    ///
    /// With `keyset == true` the reader uses keyset column `id` and an extractor that
    /// always yields `"0"`. Otherwise both keyset options are `None`.
    fn reader_with_keyset(keyset: bool) -> MySqlRdbcItemReader<'static, Dummy> {
        let pool = MySqlPool::connect_lazy("mysql://root:root@localhost/test")
            .expect("lazy pool creation should not fail");
        let column = keyset.then(|| String::from("id"));
        let key_fn = keyset.then(|| -> DummyKeyFn { Box::new(|_: &Dummy| "0".to_string()) });
        MySqlRdbcItemReader::new(pool, "SELECT 1", Some(10), column, key_fn)
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn should_initialize_without_keyset() {
        let reader = reader_with_keyset(false);
        let cursor_is_unset = reader.last_cursor.borrow().is_none();
        let buffer_is_empty = reader.buffer.borrow().is_empty();

        assert!(reader.keyset_column.is_none(), "no keyset column expected");
        assert!(reader.keyset_key.is_none(), "no keyset key fn expected");
        assert!(cursor_is_unset, "cursor must start as None");
        assert_eq!(reader.offset.get(), 0, "initial offset should be 0");
        assert!(buffer_is_empty, "buffer should start empty");
        assert_eq!(reader.page_size, Some(10));
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn should_initialize_with_keyset_column_and_none_cursor() {
        let reader = reader_with_keyset(true);
        let cursor_is_unset = reader.last_cursor.borrow().is_none();

        assert_eq!(
            reader.keyset_column.as_deref(),
            Some("id"),
            "keyset column should be stored"
        );
        assert!(reader.keyset_key.is_some(), "keyset key fn should be stored");
        assert!(cursor_is_unset, "cursor must start as None before first read");
    }
}
172
173impl<I> ItemReader<I> for MySqlRdbcItemReader<'_, I>
174where
175 for<'r> I: FromRow<'r, MySqlRow> + Send + Unpin + Clone,
176{
177 fn read(&self) -> ItemReaderResult<I> {
178 let index = calculate_page_index(self.offset.get(), self.page_size);
179
180 if should_load_page(index) {
181 self.read_page()?;
182 }
183
184 let result = self.buffer.borrow().get(index as usize).cloned();
185
186 if let (Some(item), Some(key_fn)) = (&result, &self.keyset_key) {
187 *self.last_cursor.borrow_mut() = Some(key_fn(item));
188 }
189
190 self.offset.set(self.offset.get() + 1);
191
192 Ok(result)
193 }
194}