spring_batch_rs/item/rdbc/
sqlite_reader.rs1use std::cell::{Cell, RefCell};
2
3use sqlx::{Execute, FromRow, Pool, QueryBuilder, Sqlite, sqlite::SqliteRow};
4
5use super::reader_common::{calculate_page_index, should_load_page};
6use crate::BatchError;
7use crate::core::item::{ItemReader, ItemReaderResult};
8
9pub struct SqliteRdbcItemReader<'a, I>
18where
19 for<'r> I: FromRow<'r, SqliteRow> + Send + Unpin + Clone,
20{
21 pub(crate) pool: Pool<Sqlite>,
22 pub(crate) query: &'a str,
23 pub(crate) page_size: Option<i32>,
24 pub(crate) offset: Cell<i32>,
25 pub(crate) buffer: RefCell<Vec<I>>,
26 pub(crate) keyset_column: Option<String>,
27 #[allow(clippy::type_complexity)]
28 pub(crate) keyset_key: Option<Box<dyn Fn(&I) -> String>>,
29 pub(crate) last_cursor: RefCell<Option<String>>,
30}
31
32impl<'a, I> SqliteRdbcItemReader<'a, I>
33where
34 for<'r> I: FromRow<'r, SqliteRow> + Send + Unpin + Clone,
35{
36 #[allow(clippy::type_complexity)]
41 pub(crate) fn new(
42 pool: Pool<Sqlite>,
43 query: &'a str,
44 page_size: Option<i32>,
45 keyset_column: Option<String>,
46 keyset_key: Option<Box<dyn Fn(&I) -> String>>,
47 ) -> Self {
48 Self {
49 pool,
50 query,
51 page_size,
52 offset: Cell::new(0),
53 buffer: RefCell::new(vec![]),
54 keyset_column,
55 keyset_key,
56 last_cursor: RefCell::new(None),
57 }
58 }
59
60 fn read_page(&self) -> Result<(), BatchError> {
66 let mut query_builder = QueryBuilder::<Sqlite>::new(self.query);
67
68 if let Some(page_size) = self.page_size {
69 if let Some(ref col) = self.keyset_column {
70 let last = self.last_cursor.borrow();
71 if let Some(ref cursor_val) = *last {
72 let escaped = cursor_val.replace('\'', "''");
73 query_builder.push(format!(" WHERE {} > '{}'", col, escaped));
74 }
75 query_builder.push(format!(" ORDER BY {} LIMIT {}", col, page_size));
76 } else {
77 query_builder.push(format!(" LIMIT {} OFFSET {}", page_size, self.offset.get()));
78 }
79 }
80
81 let query = query_builder.build();
82
83 let items = tokio::task::block_in_place(|| {
84 tokio::runtime::Handle::current().block_on(async {
85 sqlx::query_as::<_, I>(query.sql())
86 .fetch_all(&self.pool)
87 .await
88 .map_err(|e| BatchError::ItemReader(e.to_string()))
89 })
90 })?;
91
92 self.buffer.borrow_mut().clear();
93 self.buffer.borrow_mut().extend(items);
94 Ok(())
95 }
96}
97
98impl<I> ItemReader<I> for SqliteRdbcItemReader<'_, I>
99where
100 for<'r> I: FromRow<'r, SqliteRow> + Send + Unpin + Clone,
101{
102 fn read(&self) -> ItemReaderResult<I> {
103 let index = calculate_page_index(self.offset.get(), self.page_size);
104
105 if should_load_page(index) {
106 self.read_page()?;
107 }
108
109 let result = self.buffer.borrow().get(index as usize).cloned();
110
111 if let (Some(item), Some(key_fn)) = (&result, &self.keyset_key) {
112 *self.last_cursor.borrow_mut() = Some(key_fn(item));
113 }
114
115 self.offset.set(self.offset.get() + 1);
116
117 Ok(result)
118 }
119}
120
#[cfg(test)]
mod tests {
    //! Tests against in-memory SQLite databases. The reader bridges to async
    //! with `block_in_place`, so every test runs on the multi-threaded runtime.

    use super::*;
    use crate::core::item::ItemReader;
    use sqlx::{FromRow, SqlitePool};

    #[derive(Clone, FromRow)]
    struct Row {
        id: i32,
        name: String,
    }

    /// Creates an `items (id, name)` table and inserts `rows` into it.
    async fn pool_with_rows(rows: &[(i32, &str)]) -> SqlitePool {
        let pool = SqlitePool::connect("sqlite::memory:").await.unwrap();
        sqlx::query("CREATE TABLE items (id INTEGER, name TEXT)")
            .execute(&pool)
            .await
            .unwrap();
        for (id, name) in rows {
            sqlx::query("INSERT INTO items (id, name) VALUES (?, ?)")
                .bind(id)
                .bind(name)
                .execute(&pool)
                .await
                .unwrap();
        }
        pool
    }

    /// Builds a reader without keyset paging.
    fn make_reader(
        pool: SqlitePool,
        query: &str,
        page_size: Option<i32>,
    ) -> SqliteRdbcItemReader<'_, Row> {
        SqliteRdbcItemReader::<Row>::new(pool, query, page_size, None, None)
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn should_start_with_offset_zero_and_empty_buffer() {
        let pool = SqlitePool::connect("sqlite::memory:").await.unwrap();
        let reader = make_reader(pool, "SELECT id, name FROM items", None);
        assert_eq!(reader.offset.get(), 0, "initial offset should be 0");
        assert!(
            reader.buffer.borrow().is_empty(),
            "initial buffer should be empty"
        );
        assert_eq!(reader.page_size, None);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn should_return_none_when_table_is_empty() {
        let pool = pool_with_rows(&[]).await;
        let reader = make_reader(pool, "SELECT id, name FROM items", None);
        let result = reader.read().unwrap();
        assert!(result.is_none(), "empty table should yield None");
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn should_read_all_items_without_pagination() {
        let pool = pool_with_rows(&[(1, "alice"), (2, "bob")]).await;
        let reader = make_reader(pool, "SELECT id, name FROM items ORDER BY id", None);

        let first = reader.read().unwrap().expect("first item should exist");
        assert_eq!(first.name, "alice");

        let second = reader.read().unwrap().expect("second item should exist");
        assert_eq!(second.name, "bob");

        assert!(
            reader.read().unwrap().is_none(),
            "should return None after all items"
        );
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn should_advance_offset_on_each_read() {
        let pool = pool_with_rows(&[(1, "x"), (2, "y")]).await;
        let reader = make_reader(pool, "SELECT id, name FROM items ORDER BY id", None);

        assert_eq!(reader.offset.get(), 0);
        reader.read().unwrap();
        assert_eq!(
            reader.offset.get(),
            1,
            "offset should increment after each read"
        );
        reader.read().unwrap();
        assert_eq!(reader.offset.get(), 2);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn should_read_all_items_with_pagination() {
        let pool = pool_with_rows(&[(1, "a"), (2, "b"), (3, "c"), (4, "d")]).await;
        let reader = make_reader(pool, "SELECT id, name FROM items ORDER BY id", Some(2));

        let mut count = 0;
        while reader.read().unwrap().is_some() {
            count += 1;
        }
        assert_eq!(count, 4, "should read all 4 items across 2 pages");
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn should_read_single_item() {
        let pool = pool_with_rows(&[(42, "only")]).await;
        let reader = make_reader(pool, "SELECT id, name FROM items", None);

        let item = reader
            .read()
            .unwrap()
            .expect("should return the single item");
        assert_eq!(item.id, 42);
        assert_eq!(item.name, "only");
        assert!(
            reader.read().unwrap().is_none(),
            "should return None after the only item"
        );
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn should_read_all_items_with_keyset_pagination() {
        let pool = pool_with_rows(&[(1, "a"), (2, "b"), (3, "c"), (4, "d"), (5, "e")]).await;
        let reader = SqliteRdbcItemReader::<Row>::new(
            pool,
            "SELECT id, name FROM items",
            Some(2),
            Some("id".to_string()),
            Some(Box::new(|r: &Row| r.id.to_string())),
        );

        let mut names = vec![];
        while let Some(item) = reader.read().unwrap() {
            names.push(item.name.clone());
        }
        assert_eq!(
            names,
            vec!["a", "b", "c", "d", "e"],
            "keyset should return all items in order"
        );
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn should_page_by_text_keyset_containing_quotes() {
        // The cursor carried from one page to the next contains single quotes.
        // Fetching the following page must neither fail nor skip/duplicate rows.
        let pool = pool_with_rows(&[(1, "a'b"), (2, "c''d"), (3, "e"), (4, "f'")]).await;
        let reader = SqliteRdbcItemReader::<Row>::new(
            pool,
            "SELECT id, name FROM items",
            Some(2),
            Some("name".to_string()),
            Some(Box::new(|r: &Row| r.name.clone())),
        );

        let mut names = vec![];
        while let Some(item) = reader.read().unwrap() {
            names.push(item.name);
        }
        assert_eq!(
            names,
            vec!["a'b", "c''d", "e", "f'"],
            "text keyset should survive quotes in cursor values"
        );
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn should_update_last_cursor_after_each_read_with_keyset() {
        let pool = pool_with_rows(&[(10, "x"), (20, "y")]).await;
        let reader = SqliteRdbcItemReader::<Row>::new(
            pool,
            "SELECT id, name FROM items",
            Some(2),
            Some("id".to_string()),
            Some(Box::new(|r: &Row| r.id.to_string())),
        );

        assert!(
            reader.last_cursor.borrow().is_none(),
            "cursor should be None before first read"
        );
        reader.read().unwrap();
        assert_eq!(
            reader.last_cursor.borrow().as_deref(),
            Some("10"),
            "cursor should be updated after first read"
        );
        reader.read().unwrap();
        assert_eq!(
            reader.last_cursor.borrow().as_deref(),
            Some("20"),
            "cursor should reflect last read item"
        );
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn should_return_none_for_empty_table_with_keyset() {
        let pool = pool_with_rows(&[]).await;
        let reader = SqliteRdbcItemReader::<Row>::new(
            pool,
            "SELECT id, name FROM items",
            Some(2),
            Some("id".to_string()),
            Some(Box::new(|r: &Row| r.id.to_string())),
        );
        assert!(
            reader.read().unwrap().is_none(),
            "empty table should yield None with keyset"
        );
    }
}