Skip to main content

spring_batch_rs/item/rdbc/
sqlite_reader.rs

1use std::cell::{Cell, RefCell};
2
3use sqlx::{Execute, FromRow, Pool, QueryBuilder, Sqlite, sqlite::SqliteRow};
4
5use super::reader_common::{calculate_page_index, should_load_page};
6use crate::BatchError;
7use crate::core::item::{ItemReader, ItemReaderResult};
8
9/// SQLite RDBC Item Reader for batch processing.
10///
11/// Supports LIMIT/OFFSET pagination (default) and keyset pagination
12/// (enabled via [`RdbcItemReaderBuilder::with_keyset`]).
13///
14/// # Construction
15///
16/// Use [`RdbcItemReaderBuilder`] — direct construction is not available.
17pub struct SqliteRdbcItemReader<'a, I>
18where
19    for<'r> I: FromRow<'r, SqliteRow> + Send + Unpin + Clone,
20{
21    pub(crate) pool: Pool<Sqlite>,
22    pub(crate) query: &'a str,
23    pub(crate) page_size: Option<i32>,
24    pub(crate) offset: Cell<i32>,
25    pub(crate) buffer: RefCell<Vec<I>>,
26    pub(crate) keyset_column: Option<String>,
27    #[allow(clippy::type_complexity)]
28    pub(crate) keyset_key: Option<Box<dyn Fn(&I) -> String>>,
29    pub(crate) last_cursor: RefCell<Option<String>>,
30}
31
32impl<'a, I> SqliteRdbcItemReader<'a, I>
33where
34    for<'r> I: FromRow<'r, SqliteRow> + Send + Unpin + Clone,
35{
36    /// Creates a new SqliteRdbcItemReader with the specified parameters
37    ///
38    /// This constructor is only accessible within the crate to enforce the use
39    /// of `RdbcItemReaderBuilder` for creating reader instances.
40    #[allow(clippy::type_complexity)]
41    pub(crate) fn new(
42        pool: Pool<Sqlite>,
43        query: &'a str,
44        page_size: Option<i32>,
45        keyset_column: Option<String>,
46        keyset_key: Option<Box<dyn Fn(&I) -> String>>,
47    ) -> Self {
48        Self {
49            pool,
50            query,
51            page_size,
52            offset: Cell::new(0),
53            buffer: RefCell::new(vec![]),
54            keyset_column,
55            keyset_key,
56            last_cursor: RefCell::new(None),
57        }
58    }
59
60    /// Fetches the next page from the database into the internal buffer.
61    ///
62    /// # Errors
63    ///
64    /// Returns [`BatchError::ItemReader`] if the query fails.
65    fn read_page(&self) -> Result<(), BatchError> {
66        let mut query_builder = QueryBuilder::<Sqlite>::new(self.query);
67
68        if let Some(page_size) = self.page_size {
69            if let Some(ref col) = self.keyset_column {
70                let last = self.last_cursor.borrow();
71                if let Some(ref cursor_val) = *last {
72                    let escaped = cursor_val.replace('\'', "''");
73                    query_builder.push(format!(" WHERE {} > '{}'", col, escaped));
74                }
75                query_builder.push(format!(" ORDER BY {} LIMIT {}", col, page_size));
76            } else {
77                query_builder.push(format!(" LIMIT {} OFFSET {}", page_size, self.offset.get()));
78            }
79        }
80
81        let query = query_builder.build();
82
83        let items = tokio::task::block_in_place(|| {
84            tokio::runtime::Handle::current().block_on(async {
85                sqlx::query_as::<_, I>(query.sql())
86                    .fetch_all(&self.pool)
87                    .await
88                    .map_err(|e| BatchError::ItemReader(e.to_string()))
89            })
90        })?;
91
92        self.buffer.borrow_mut().clear();
93        self.buffer.borrow_mut().extend(items);
94        Ok(())
95    }
96}
97
98impl<I> ItemReader<I> for SqliteRdbcItemReader<'_, I>
99where
100    for<'r> I: FromRow<'r, SqliteRow> + Send + Unpin + Clone,
101{
102    fn read(&self) -> ItemReaderResult<I> {
103        let index = calculate_page_index(self.offset.get(), self.page_size);
104
105        if should_load_page(index) {
106            self.read_page()?;
107        }
108
109        let result = self.buffer.borrow().get(index as usize).cloned();
110
111        if let (Some(item), Some(key_fn)) = (&result, &self.keyset_key) {
112            *self.last_cursor.borrow_mut() = Some(key_fn(item));
113        }
114
115        self.offset.set(self.offset.get() + 1);
116
117        Ok(result)
118    }
119}
120
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::item::ItemReader;
    use sqlx::{FromRow, SqlitePool};

    /// Row shape of the `items` table used throughout these tests.
    #[derive(Clone, FromRow)]
    struct Row {
        id: i32,
        name: String,
    }

    /// Opens an in-memory database, creates `items`, and inserts `rows` in order.
    async fn pool_with_rows(rows: &[(i32, &str)]) -> SqlitePool {
        let pool = SqlitePool::connect("sqlite::memory:").await.unwrap();

        sqlx::query("CREATE TABLE items (id INTEGER, name TEXT)")
            .execute(&pool)
            .await
            .unwrap();

        for (id, name) in rows.iter() {
            let insert = sqlx::query("INSERT INTO items (id, name) VALUES (?, ?)")
                .bind(id)
                .bind(name);
            insert.execute(&pool).await.unwrap();
        }

        pool
    }

    /// Reader without keyset pagination over the given query.
    fn make_reader(
        pool: SqlitePool,
        query: &str,
        page_size: Option<i32>,
    ) -> SqliteRdbcItemReader<'_, Row> {
        SqliteRdbcItemReader::<Row>::new(pool, query, page_size, None, None)
    }

    /// Reader using keyset pagination on `id` with a page size of two.
    fn keyset_reader(pool: SqlitePool) -> SqliteRdbcItemReader<'static, Row> {
        SqliteRdbcItemReader::<Row>::new(
            pool,
            "SELECT id, name FROM items",
            Some(2),
            Some("id".to_string()),
            Some(Box::new(|r: &Row| r.id.to_string())),
        )
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn should_start_with_offset_zero_and_empty_buffer() {
        let pool = SqlitePool::connect("sqlite::memory:").await.unwrap();
        let reader = make_reader(pool, "SELECT id, name FROM items", None);

        assert_eq!(reader.offset.get(), 0, "initial offset should be 0");
        let buffer_is_empty = reader.buffer.borrow().is_empty();
        assert!(buffer_is_empty, "initial buffer should be empty");
        assert_eq!(reader.page_size, None);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn should_return_none_when_table_is_empty() {
        let pool = pool_with_rows(&[]).await;
        let reader = make_reader(pool, "SELECT id, name FROM items", None);

        let outcome = reader.read().unwrap();
        assert!(outcome.is_none(), "empty table should yield None");
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn should_read_all_items_without_pagination() {
        let pool = pool_with_rows(&[(1, "alice"), (2, "bob")]).await;
        let reader = make_reader(pool, "SELECT id, name FROM items ORDER BY id", None);

        let expectations = [
            ("alice", "first item should exist"),
            ("bob", "second item should exist"),
        ];
        for (expected_name, missing_msg) in expectations {
            let item = reader.read().unwrap().expect(missing_msg);
            assert_eq!(item.name, expected_name);
        }

        let exhausted = reader.read().unwrap().is_none();
        assert!(exhausted, "should return None after all items");
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn should_advance_offset_on_each_read() {
        let pool = pool_with_rows(&[(1, "x"), (2, "y")]).await;
        let reader = make_reader(pool, "SELECT id, name FROM items ORDER BY id", None);

        assert_eq!(reader.offset.get(), 0);

        reader.read().unwrap();
        let after_first = reader.offset.get();
        assert_eq!(after_first, 1, "offset should increment after each read");

        reader.read().unwrap();
        let after_second = reader.offset.get();
        assert_eq!(after_second, 2);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn should_read_all_items_with_pagination() {
        let pool = pool_with_rows(&[(1, "a"), (2, "b"), (3, "c"), (4, "d")]).await;
        let reader = make_reader(pool, "SELECT id, name FROM items ORDER BY id", Some(2));

        // Pull items until the reader reports exhaustion.
        let count = std::iter::from_fn(|| reader.read().unwrap()).count();
        assert_eq!(count, 4, "should read all 4 items across 2 pages");
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn should_read_single_item() {
        let pool = pool_with_rows(&[(42, "only")]).await;
        let reader = make_reader(pool, "SELECT id, name FROM items", None);

        let only = reader
            .read()
            .unwrap()
            .expect("should return the single item");
        assert_eq!(only.id, 42);
        assert_eq!(only.name, "only");

        let exhausted = reader.read().unwrap().is_none();
        assert!(exhausted, "should return None after the only item");
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn should_read_all_items_with_keyset_pagination() {
        let pool = pool_with_rows(&[(1, "a"), (2, "b"), (3, "c"), (4, "d"), (5, "e")]).await;
        let reader = keyset_reader(pool);

        let names: Vec<String> = std::iter::from_fn(|| reader.read().unwrap())
            .map(|item| item.name.clone())
            .collect();

        assert_eq!(
            names,
            vec!["a", "b", "c", "d", "e"],
            "keyset should return all items in order"
        );
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn should_update_last_cursor_after_each_read_with_keyset() {
        let pool = pool_with_rows(&[(10, "x"), (20, "y")]).await;
        let reader = keyset_reader(pool);

        let cursor_unset = reader.last_cursor.borrow().is_none();
        assert!(cursor_unset, "cursor should be None before first read");

        let steps = [
            ("10", "cursor should be updated after first read"),
            ("20", "cursor should reflect last read item"),
        ];
        for (expected_cursor, message) in steps {
            reader.read().unwrap();
            assert_eq!(
                reader.last_cursor.borrow().as_deref(),
                Some(expected_cursor),
                "{}",
                message
            );
        }
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn should_return_none_for_empty_table_with_keyset() {
        let pool = pool_with_rows(&[]).await;
        let reader = keyset_reader(pool);

        let outcome = reader.read().unwrap();
        assert!(outcome.is_none(), "empty table should yield None with keyset");
    }
}