Skip to main content

spring_batch_rs/item/rdbc/
sqlite_reader.rs

1use std::cell::{Cell, RefCell};
2
3use sqlx::{sqlite::SqliteRow, Execute, FromRow, Pool, QueryBuilder, Sqlite};
4
5use super::reader_common::{calculate_page_index, should_load_page};
6use crate::core::item::{ItemReader, ItemReaderResult};
7use crate::BatchError;
8
9/// SQLite RDBC Item Reader for batch processing
10///
11/// # Construction
12///
13/// This reader can only be created through `RdbcItemReaderBuilder`.
14/// Direct construction is not available to ensure proper configuration.
15pub struct SqliteRdbcItemReader<'a, I>
16where
17    for<'r> I: FromRow<'r, SqliteRow> + Send + Unpin + Clone,
18{
19    pub(crate) pool: Pool<Sqlite>,
20    pub(crate) query: &'a str,
21    pub(crate) page_size: Option<i32>,
22    pub(crate) offset: Cell<i32>,
23    pub(crate) buffer: RefCell<Vec<I>>,
24}
25
26impl<'a, I> SqliteRdbcItemReader<'a, I>
27where
28    for<'r> I: FromRow<'r, SqliteRow> + Send + Unpin + Clone,
29{
30    /// Creates a new SqliteRdbcItemReader with the specified parameters
31    ///
32    /// This constructor is only accessible within the crate to enforce the use
33    /// of `RdbcItemReaderBuilder` for creating reader instances.
34    pub(crate) fn new(pool: Pool<Sqlite>, query: &'a str, page_size: Option<i32>) -> Self {
35        Self {
36            pool,
37            query,
38            page_size,
39            offset: Cell::new(0),
40            buffer: RefCell::new(vec![]),
41        }
42    }
43
44    /// Reads a page of data from the database and stores it in the internal buffer.
45    ///
46    /// # Errors
47    ///
48    /// Returns [`BatchError::ItemReader`] if the database query fails.
49    fn read_page(&self) -> Result<(), BatchError> {
50        let mut query_builder = QueryBuilder::<Sqlite>::new(self.query);
51
52        if let Some(page_size) = self.page_size {
53            query_builder.push(format!(" LIMIT {} OFFSET {}", page_size, self.offset.get()));
54        }
55
56        let query = query_builder.build();
57
58        let items = tokio::task::block_in_place(|| {
59            tokio::runtime::Handle::current().block_on(async {
60                sqlx::query_as::<_, I>(query.sql())
61                    .fetch_all(&self.pool)
62                    .await
63                    .map_err(|e| BatchError::ItemReader(e.to_string()))
64            })
65        })?;
66
67        self.buffer.borrow_mut().clear();
68        self.buffer.borrow_mut().extend(items);
69        Ok(())
70    }
71}
72
73impl<I> ItemReader<I> for SqliteRdbcItemReader<'_, I>
74where
75    for<'r> I: FromRow<'r, SqliteRow> + Send + Unpin + Clone,
76{
77    /// Reads the next item from the SQLite database
78    fn read(&self) -> ItemReaderResult<I> {
79        let index = calculate_page_index(self.offset.get(), self.page_size);
80
81        if should_load_page(index) {
82            self.read_page()?;
83        }
84
85        let buffer = self.buffer.borrow();
86        let result = buffer.get(index as usize);
87
88        self.offset.set(self.offset.get() + 1);
89
90        Ok(result.cloned())
91    }
92}
93
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::item::ItemReader;
    use sqlx::{FromRow, SqlitePool};

    #[derive(Clone, FromRow)]
    struct Row {
        id: i32,
        name: String,
    }

    /// Creates an in-memory SQLite pool with an `items` table holding `rows`.
    async fn pool_with_rows(rows: &[(i32, &str)]) -> SqlitePool {
        let pool = SqlitePool::connect("sqlite::memory:").await.unwrap();
        sqlx::query("CREATE TABLE items (id INTEGER, name TEXT)")
            .execute(&pool)
            .await
            .unwrap();
        for (id, name) in rows {
            sqlx::query("INSERT INTO items (id, name) VALUES (?, ?)")
                .bind(id)
                .bind(name)
                .execute(&pool)
                .await
                .unwrap();
        }
        pool
    }

    /// Reads from `reader` until it reports the end of data and returns the
    /// names in the order they were read.
    fn read_all_names(reader: &SqliteRdbcItemReader<'_, Row>) -> Vec<String> {
        std::iter::from_fn(|| reader.read().expect("read should succeed"))
            .map(|row| row.name)
            .collect()
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn should_start_with_offset_zero_and_empty_buffer() {
        let pool = SqlitePool::connect("sqlite::memory:").await.unwrap();
        let reader = SqliteRdbcItemReader::<Row>::new(pool, "SELECT id, name FROM items", None);
        assert_eq!(reader.offset.get(), 0, "initial offset should be 0");
        assert!(
            reader.buffer.borrow().is_empty(),
            "initial buffer should be empty"
        );
        assert_eq!(reader.page_size, None);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn should_return_none_when_table_is_empty() {
        let pool = pool_with_rows(&[]).await;
        let reader = SqliteRdbcItemReader::<Row>::new(pool, "SELECT id, name FROM items", None);
        let result = reader.read().unwrap();
        assert!(result.is_none(), "empty table should yield None");
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn should_read_all_items_without_pagination() {
        let pool = pool_with_rows(&[(1, "alice"), (2, "bob")]).await;
        let reader =
            SqliteRdbcItemReader::<Row>::new(pool, "SELECT id, name FROM items ORDER BY id", None);

        let first = reader.read().unwrap().expect("first item should exist");
        assert_eq!(first.name, "alice");

        let second = reader.read().unwrap().expect("second item should exist");
        assert_eq!(second.name, "bob");

        assert!(
            reader.read().unwrap().is_none(),
            "should return None after all items"
        );
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn should_advance_offset_on_each_read() {
        let pool = pool_with_rows(&[(1, "x"), (2, "y")]).await;
        let reader =
            SqliteRdbcItemReader::<Row>::new(pool, "SELECT id, name FROM items ORDER BY id", None);

        assert_eq!(reader.offset.get(), 0);
        reader.read().unwrap();
        assert_eq!(
            reader.offset.get(),
            1,
            "offset should increment after each read"
        );
        reader.read().unwrap();
        assert_eq!(reader.offset.get(), 2);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn should_read_all_items_with_pagination() {
        let pool = pool_with_rows(&[(1, "a"), (2, "b"), (3, "c"), (4, "d")]).await;
        let reader = SqliteRdbcItemReader::<Row>::new(
            pool,
            "SELECT id, name FROM items ORDER BY id",
            Some(2), // page_size = 2
        );

        // Checking the exact sequence (not just the count) catches pages that
        // are re-read or skipped.
        assert_eq!(
            read_all_names(&reader),
            ["a", "b", "c", "d"],
            "should read all 4 items across 2 pages, in order"
        );
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn should_read_partial_last_page() {
        let pool = pool_with_rows(&[(1, "a"), (2, "b"), (3, "c")]).await;
        let reader = SqliteRdbcItemReader::<Row>::new(
            pool,
            "SELECT id, name FROM items ORDER BY id",
            Some(2), // last page holds a single row
        );

        assert_eq!(
            read_all_names(&reader),
            ["a", "b", "c"],
            "should read a short final page"
        );
        assert!(
            reader.read().unwrap().is_none(),
            "should keep returning None after the data is exhausted"
        );
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn should_read_single_item() {
        let pool = pool_with_rows(&[(42, "only")]).await;
        let reader = SqliteRdbcItemReader::<Row>::new(pool, "SELECT id, name FROM items", None);

        let item = reader
            .read()
            .unwrap()
            .expect("should return the single item");
        assert_eq!(item.id, 42);
        assert_eq!(item.name, "only");
        assert!(
            reader.read().unwrap().is_none(),
            "should return None after the only item"
        );
    }
}