// spring_batch_rs/item/rdbc/postgres_reader.rs

1use std::cell::{Cell, RefCell};
2
3use sqlx::{Execute, FromRow, Pool, Postgres, QueryBuilder, postgres::PgRow};
4
5use super::reader_common::{calculate_page_index, should_load_page};
6use crate::BatchError;
7use crate::core::item::{ItemReader, ItemReaderResult};
8
9/// PostgreSQL RDBC Item Reader for batch processing.
10///
11/// Supports two pagination strategies:
12///
13/// - **LIMIT/OFFSET** (default): simple but degrades at large offsets — use for small datasets.
14/// - **Keyset pagination** (recommended for large datasets): uses `WHERE col > :last ORDER BY col LIMIT n`,
15///   O(log n) per page regardless of dataset size. Enable with [`RdbcItemReaderBuilder::with_keyset`].
16///
17/// # Type Parameters
18///
19/// * `I` - Must implement `FromRow<PgRow> + Send + Unpin + Clone`.
20///
21/// # Construction
22///
23/// Use [`RdbcItemReaderBuilder`] — direct construction is not available.
24pub struct PostgresRdbcItemReader<'a, I>
25where
26    for<'r> I: FromRow<'r, PgRow> + Send + Unpin + Clone,
27{
28    pub(crate) pool: Pool<Postgres>,
29    pub(crate) query: &'a str,
30    pub(crate) page_size: Option<i32>,
31    pub(crate) offset: Cell<i32>,
32    pub(crate) buffer: RefCell<Vec<I>>,
33    /// Column name used as the keyset cursor (e.g. `"id"`).
34    pub(crate) keyset_column: Option<String>,
35    /// Extracts the cursor value from an item for use in the next page's WHERE clause.
36    #[allow(clippy::type_complexity)]
37    pub(crate) keyset_key: Option<Box<dyn Fn(&I) -> String>>,
38    /// Last cursor value seen; drives the WHERE clause on subsequent pages.
39    pub(crate) last_cursor: RefCell<Option<String>>,
40}
41
42impl<'a, I> PostgresRdbcItemReader<'a, I>
43where
44    for<'r> I: FromRow<'r, PgRow> + Send + Unpin + Clone,
45{
46    /// Creates a new PostgresRdbcItemReader with the specified parameters
47    ///
48    /// This constructor is only accessible within the crate to enforce the use
49    /// of `RdbcItemReaderBuilder` for creating reader instances.
50    ///
51    /// # Arguments
52    ///
53    /// * `pool` - PostgreSQL connection pool for database operations
54    /// * `query` - SQL query to execute (without LIMIT/OFFSET)
55    /// * `page_size` - Optional page size for pagination. None means read all at once.
56    ///
57    /// # Returns
58    ///
59    /// A new `PostgresRdbcItemReader` instance ready for use.
60    #[allow(clippy::type_complexity)]
61    pub fn new(
62        pool: Pool<Postgres>,
63        query: &'a str,
64        page_size: Option<i32>,
65        keyset_column: Option<String>,
66        keyset_key: Option<Box<dyn Fn(&I) -> String>>,
67    ) -> Self {
68        Self {
69            pool,
70            query,
71            page_size,
72            offset: Cell::new(0),
73            buffer: RefCell::new(vec![]),
74            keyset_column,
75            keyset_key,
76            last_cursor: RefCell::new(None),
77        }
78    }
79
80    /// Fetches the next page from the database into the internal buffer.
81    ///
82    /// Uses keyset pagination when [`keyset_column`] is set, otherwise falls back
83    /// to LIMIT/OFFSET.
84    ///
85    /// # Errors
86    ///
87    /// Returns [`BatchError::ItemReader`] if the query fails.
88    fn read_page(&self) -> Result<(), BatchError> {
89        let mut query_builder = QueryBuilder::<Postgres>::new(self.query);
90
91        if let Some(page_size) = self.page_size {
92            if let Some(ref col) = self.keyset_column {
93                let last = self.last_cursor.borrow();
94                if let Some(ref cursor_val) = *last {
95                    // Escape single quotes to prevent SQL injection from cursor values.
96                    let escaped = cursor_val.replace('\'', "''");
97                    query_builder.push(format!(" WHERE {} > '{}'", col, escaped));
98                }
99                query_builder.push(format!(" ORDER BY {} LIMIT {}", col, page_size));
100            } else {
101                query_builder.push(format!(" LIMIT {} OFFSET {}", page_size, self.offset.get()));
102            }
103        }
104
105        let query = query_builder.build();
106
107        let items = tokio::task::block_in_place(|| {
108            tokio::runtime::Handle::current().block_on(async {
109                sqlx::query_as::<_, I>(query.sql())
110                    .fetch_all(&self.pool)
111                    .await
112                    .map_err(|e| BatchError::ItemReader(e.to_string()))
113            })
114        })?;
115
116        self.buffer.borrow_mut().clear();
117        self.buffer.borrow_mut().extend(items);
118        Ok(())
119    }
120}
121
// Unit tests for `PostgresRdbcItemReader` construction. They only inspect the initial
// state of a freshly built reader. The pool comes from `connect_lazy`, so no
// connection is opened up front.
#[cfg(test)]
mod tests {
    use super::*;
    use sqlx::PgPool;

    /// Row type that ignores column contents; only used to satisfy `FromRow`.
    #[derive(Clone)]
    struct Dummy;

    impl<'r> sqlx::FromRow<'r, sqlx::postgres::PgRow> for Dummy {
        fn from_row(_row: &'r sqlx::postgres::PgRow) -> Result<Self, sqlx::Error> {
            Ok(Dummy)
        }
    }

    /// Builds a reader over a lazy pool with page size 10, with or without keyset
    /// pagination on column `id`.
    fn reader_with_keyset(keyset: bool) -> PostgresRdbcItemReader<'static, Dummy> {
        let pool = PgPool::connect_lazy("postgres://postgres:postgres@localhost/test")
            .expect("lazy pool creation should not fail");
        let (col, key): (Option<String>, Option<Box<dyn Fn(&Dummy) -> String>>) = if keyset {
            (
                Some("id".to_string()),
                Some(Box::new(|_: &Dummy| "0".to_string())),
            )
        } else {
            (None, None)
        };
        PostgresRdbcItemReader::new(pool, "SELECT 1", Some(10), col, key)
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn should_initialize_without_keyset() {
        let reader = reader_with_keyset(false);
        assert!(reader.keyset_column.is_none(), "no keyset column expected");
        assert!(reader.keyset_key.is_none(), "no keyset key fn expected");
        assert!(
            reader.last_cursor.borrow().is_none(),
            "cursor must start as None"
        );
        assert_eq!(reader.offset.get(), 0, "initial offset should be 0");
        assert!(
            reader.buffer.borrow().is_empty(),
            "buffer should start empty"
        );
        assert_eq!(reader.page_size, Some(10));
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn should_initialize_with_keyset_column_and_none_cursor() {
        let reader = reader_with_keyset(true);
        assert_eq!(
            reader.keyset_column.as_deref(),
            Some("id"),
            "keyset column should be stored"
        );
        assert!(
            reader.keyset_key.is_some(),
            "keyset key fn should be stored"
        );
        assert!(
            reader.last_cursor.borrow().is_none(),
            "cursor must start as None before first read"
        );
    }
}
194
195impl<I> ItemReader<I> for PostgresRdbcItemReader<'_, I>
196where
197    for<'r> I: FromRow<'r, PgRow> + Send + Unpin + Clone,
198{
199    /// Reads the next item from the PostgreSQL database
200    ///
201    /// This method implements the ItemReader trait and provides the core reading logic
202    /// with automatic pagination management:
203    ///
204    /// 1. **Index Calculation**: Determines the current position within the current page
205    /// 2. **Page Loading**: Loads a new page if we're at the beginning of a page
206    /// 3. **Item Retrieval**: Returns the item at the current position from the buffer
207    /// 4. **Offset Management**: Advances the offset for the next read operation
208    ///
209    /// # Pagination Logic
210    ///
211    /// For paginated reading (when page_size is Some):
212    /// - `index = offset % page_size` gives position within current page
213    /// - When `index == 0`, we're at the start of a new page and need to load data
214    /// - Buffer contains only the current page's items
215    ///
216    /// For non-paginated reading (when page_size is None):
217    /// - `index = offset` gives absolute position in the full result set
218    /// - Data is loaded only once when `index == 0` (first read)
219    /// - Buffer contains all items from the query
220    ///
221    /// # Returns
222    ///
223    /// - `Ok(Some(item))` if an item was successfully read
224    /// - `Ok(None)` if there are no more items to read (end of result set)
225    /// - `Err(BatchError::ItemReader)` if a database error occurred
226    ///
227    /// # Examples
228    ///
229    /// ```ignore
230    /// use spring_batch_rs::core::item::ItemReader;
231    /// # use spring_batch_rs::item::rdbc::PostgresRdbcItemReader;
232    /// # use sqlx::PgPool;
233    /// # use serde::Deserialize;
234    /// # #[derive(sqlx::FromRow, Clone, Deserialize)]
235    /// # struct User { id: i32, name: String }
236    ///
237    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
238    /// # let pool = PgPool::connect("postgresql://user:pass@localhost/db").await?;
239    /// let reader = PostgresRdbcItemReader::<User>::new(
240    ///     pool,
241    ///     "SELECT id, name FROM users ORDER BY id",
242    ///     Some(100)
243    /// );
244    ///
245    /// // Read items one by one
246    /// let mut count = 0;
247    /// while let Some(user) = reader.read()? {
248    ///     println!("User: {} - {}", user.id, user.name);
249    ///     count += 1;
250    /// }
251    /// println!("Processed {} users", count);
252    /// # Ok(())
253    /// # }
254    /// ```
255    fn read(&self) -> ItemReaderResult<I> {
256        let index = calculate_page_index(self.offset.get(), self.page_size);
257
258        if should_load_page(index) {
259            self.read_page()?;
260        }
261
262        let result = self.buffer.borrow().get(index as usize).cloned();
263
264        if let (Some(item), Some(key_fn)) = (&result, &self.keyset_key) {
265            *self.last_cursor.borrow_mut() = Some(key_fn(item));
266        }
267
268        self.offset.set(self.offset.get() + 1);
269
270        Ok(result)
271    }
272}