// tokio_postgres_cursor/cursor/stream.rs

1use std::{
2    pin::Pin,
3    sync::Arc,
4    task::{Context, Poll, ready},
5};
6
7use rand::{
8    distr::{Alphanumeric, Distribution},
9    rng,
10};
11
12use futures_core::Stream;
13use tokio_postgres::{Error, Row, Transaction};
14
15/// A stream that fetches rows from a PostgreSQL cursor in batches.
16pub struct CursorStream<'a> {
17    tx: Arc<&'a Transaction<'a>>,
18    cursor: Arc<String>,
19    batch_size: usize,
20    future: Option<Pin<Box<dyn Future<Output = Result<Vec<Row>, Error>> + Send + 'a>>>,
21    done: bool,
22}
23
24impl<'a> CursorStream<'a> {
25    /// Creates a new [`CursorStream`] and declares a cursor for the given query.
26    ///
27    /// Parameters:
28    /// - `tx`: A reference to the transaction in which the cursor will be declared.
29    /// - `query`: The SQL query for which the cursor will be declared.
30    /// - `batch_size`: The number of rows to fetch in each batch.
31    ///
32    /// Errors:
33    /// - Propagates
34    /// [`tokio_postgres::Error`](https://docs.rs/tokio-postgres/latest/tokio_postgres/error/struct.Error.html)
35    /// if the cursor declaration fails.
36    pub(crate) async fn new(
37        tx: &'a Transaction<'a>,
38        query: &str,
39        batch_size: usize,
40    ) -> Result<Self, Error> {
41        let cursor = format!(
42            "cursor_{}",
43            Alphanumeric
44                .sample_iter(rng())
45                .take(3)
46                .map(|x| x as char)
47                .collect::<String>()
48        );
49        tx.execute(
50            &format!("DECLARE {} NO SCROLL CURSOR FOR {}", cursor, query),
51            &[],
52        )
53        .await?;
54
55        Ok(Self {
56            tx: Arc::new(tx),
57            cursor: Arc::new(cursor),
58            batch_size,
59            future: None,
60            done: false,
61        })
62    }
63
64    /// Closes the cursor associated with this stream.
65    ///
66    /// Errors:
67    /// - Propagates
68    /// [`tokio_postgres::Error`](https://docs.rs/tokio-postgres/latest/tokio_postgres/error/struct.Error.html)
69    /// if the cursor closing fails
70    pub async fn close(mut self) -> Result<u64, Error> {
71        self.done = true;
72        self.tx
73            .execute(&format!("CLOSE {}", self.cursor), &[])
74            .await
75    }
76}
77
78impl<'a> Stream for CursorStream<'a> {
79    type Item = Result<Vec<Row>, Error>;
80
81    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
82        if self.done {
83            return Poll::Ready(None);
84        }
85
86        if self.future.is_none() {
87            let tx = self.tx.clone();
88            let cursor = self.cursor.clone();
89            let batch_size = self.batch_size;
90
91            let future = Box::pin(async move {
92                tx.query(
93                    &format!("FETCH FORWARD {} FROM {}", batch_size, cursor),
94                    &[],
95                )
96                .await
97            });
98
99            self.future = Some(future);
100        }
101
102        match ready!(self.future.as_mut().unwrap().as_mut().poll(cx)) {
103            Ok(rows) => {
104                self.future = None;
105                if rows.is_empty() {
106                    Poll::Ready(None)
107                } else {
108                    Poll::Ready(Some(Ok(rows)))
109                }
110            }
111            Err(e) => {
112                self.future = None;
113                Poll::Ready(Some(Err(e)))
114            }
115        }
116    }
117}