xitca_postgres/query/
stream.rs

1use core::{
2    future::Future,
3    marker::PhantomData,
4    ops::Range,
5    pin::Pin,
6    task::{ready, Context, Poll},
7};
8
9use std::sync::Arc;
10
11use fallible_iterator::FallibleIterator;
12use postgres_protocol::message::backend;
13
14use crate::{
15    column::Column,
16    driver::codec::Response,
17    error::Error,
18    iter::AsyncLendingIterator,
19    prepare::Prepare,
20    row::{marker, Row, RowOwned, RowSimple, RowSimpleOwned},
21    types::Type,
22};
23
24#[derive(Debug)]
25pub struct GenericRowStream<C, M> {
26    pub(crate) res: Response,
27    pub(crate) col: C,
28    pub(crate) ranges: Vec<Range<usize>>,
29    pub(crate) _marker: PhantomData<M>,
30}
31
32impl<C, M> GenericRowStream<C, M> {
33    pub(crate) fn new(res: Response, col: C) -> Self {
34        Self {
35            res,
36            col,
37            ranges: Vec::new(),
38            _marker: PhantomData,
39        }
40    }
41}
42
43/// A stream of table rows.
44pub type RowStream<'a> = GenericRowStream<&'a [Column], marker::Typed>;
45
46impl<'a> AsyncLendingIterator for RowStream<'a> {
47    type Ok<'i>
48        = Row<'i>
49    where
50        Self: 'i;
51    type Err = Error;
52
53    #[inline]
54    fn try_next(&mut self) -> impl Future<Output = Result<Option<Self::Ok<'_>>, Self::Err>> + Send {
55        try_next(&mut self.res, self.col, &mut self.ranges)
56    }
57}
58
59async fn try_next<'r>(
60    res: &mut Response,
61    col: &'r [Column],
62    ranges: &'r mut Vec<Range<usize>>,
63) -> Result<Option<Row<'r>>, Error> {
64    loop {
65        match res.recv().await? {
66            backend::Message::DataRow(body) => return Row::try_new(col, body, ranges).map(Some),
67            backend::Message::BindComplete
68            | backend::Message::EmptyQueryResponse
69            | backend::Message::CommandComplete(_)
70            | backend::Message::PortalSuspended => {}
71            backend::Message::ReadyForQuery(_) => return Ok(None),
72            _ => return Err(Error::unexpected()),
73        }
74    }
75}
76
77/// [`RowStream`] with static lifetime
78///
79/// # Usage
80/// due to Rust's GAT limitation [`AsyncLendingIterator`] only works well type that have static lifetime.
81/// actively converting a [`RowStream`] to [`RowStreamOwned`] will opens up convenient high level APIs at some additional
82/// cost (More memory allocation)
83///
84/// # Examples
85/// ```
86/// # use xitca_postgres::{iter::{AsyncLendingIterator, AsyncLendingIteratorExt}, Client, Error, Execute, RowStreamOwned, Statement};
87/// # async fn collect(cli: Client) -> Result<(), Error> {
88/// // prepare statement and query for some users from database.
89/// let stmt = Statement::named("SELECT * FROM users", &[]).execute(&cli).await?;
90/// let mut stream = stmt.query(&cli).await?;
91///
92/// // assuming users contain name column where it can be parsed to string.
93/// // then collecting all user name to a collection
94/// let mut strings = Vec::new();
95/// while let Some(row) = stream.try_next().await? {
96///     strings.push(row.get::<String>("name"));
97/// }
98///
99/// // the same operation with owned row stream can be simplified a bit:
100/// let stream = stmt.query(&cli).await?;
101/// // use extended api on top of AsyncIterator to collect user names to collection
102/// let strings_2: Vec<String> = RowStreamOwned::from(stream).map_ok(|row| row.get("name")).try_collect().await?;
103///
104/// assert_eq!(strings, strings_2);
105/// # Ok(())
106/// # }
107/// ```
108pub type RowStreamOwned = GenericRowStream<Arc<[Column]>, marker::Typed>;
109
110impl From<RowStream<'_>> for RowStreamOwned {
111    fn from(stream: RowStream<'_>) -> Self {
112        Self {
113            res: stream.res,
114            col: Arc::from(stream.col),
115            ranges: stream.ranges,
116            _marker: PhantomData,
117        }
118    }
119}
120
121impl AsyncLendingIterator for RowStreamOwned {
122    type Ok<'i>
123        = Row<'i>
124    where
125        Self: 'i;
126    type Err = Error;
127
128    #[inline]
129    fn try_next(&mut self) -> impl Future<Output = Result<Option<Self::Ok<'_>>, Self::Err>> + Send {
130        try_next(&mut self.res, &self.col, &mut self.ranges)
131    }
132}
133
134impl IntoIterator for RowStream<'_> {
135    type Item = Result<RowOwned, Error>;
136    type IntoIter = RowStreamOwned;
137
138    fn into_iter(self) -> Self::IntoIter {
139        RowStreamOwned::from(self)
140    }
141}
142
143impl Iterator for RowStreamOwned {
144    type Item = Result<RowOwned, Error>;
145
146    fn next(&mut self) -> Option<Self::Item> {
147        loop {
148            match self.res.blocking_recv() {
149                Ok(msg) => match msg {
150                    backend::Message::DataRow(body) => {
151                        return Some(RowOwned::try_new(self.col.clone(), body, Vec::new()))
152                    }
153                    backend::Message::BindComplete
154                    | backend::Message::EmptyQueryResponse
155                    | backend::Message::CommandComplete(_)
156                    | backend::Message::PortalSuspended => {}
157                    backend::Message::ReadyForQuery(_) => return None,
158                    _ => return Some(Err(Error::unexpected())),
159                },
160                Err(e) => return Some(Err(e)),
161            }
162        }
163    }
164}
165
166/// A stream of simple query results.
167pub type RowSimpleStream = GenericRowStream<Vec<Column>, marker::NoTyped>;
168
169impl AsyncLendingIterator for RowSimpleStream {
170    type Ok<'i>
171        = RowSimple<'i>
172    where
173        Self: 'i;
174    type Err = Error;
175
176    async fn try_next(&mut self) -> Result<Option<Self::Ok<'_>>, Self::Err> {
177        loop {
178            match self.res.recv().await? {
179                backend::Message::RowDescription(body) => {
180                    self.col = body
181                        .fields()
182                        // text type is used to match RowSimple::try_get's implementation
183                        // where column's pg type is always assumed as Option<&str>.
184                        // (no runtime pg type check so this does not really matter. it's
185                        // better to keep the type consistent though)
186                        .map(|f| Ok(Column::new(f.name(), Type::TEXT)))
187                        .collect::<Vec<_>>()?;
188                }
189                backend::Message::DataRow(body) => {
190                    return RowSimple::try_new(&self.col, body, &mut self.ranges).map(Some);
191                }
192                backend::Message::CommandComplete(_) | backend::Message::EmptyQueryResponse => {}
193                backend::Message::ReadyForQuery(_) => return Ok(None),
194                _ => return Err(Error::unexpected()),
195            }
196        }
197    }
198}
199
200/// [`RowSimpleStreamOwned`] with static lifetime
201pub type RowSimpleStreamOwned = GenericRowStream<Arc<[Column]>, marker::NoTyped>;
202
203impl From<RowSimpleStream> for RowSimpleStreamOwned {
204    fn from(stream: RowSimpleStream) -> Self {
205        Self {
206            res: stream.res,
207            col: stream.col.into(),
208            ranges: stream.ranges,
209            _marker: PhantomData,
210        }
211    }
212}
213
214impl IntoIterator for RowSimpleStream {
215    type IntoIter = RowSimpleStreamOwned;
216    type Item = Result<RowSimpleOwned, Error>;
217
218    fn into_iter(self) -> Self::IntoIter {
219        RowSimpleStreamOwned::from(self)
220    }
221}
222
223impl Iterator for RowSimpleStreamOwned {
224    type Item = Result<RowSimpleOwned, Error>;
225
226    fn next(&mut self) -> Option<Self::Item> {
227        loop {
228            match self.res.blocking_recv() {
229                Ok(msg) => match msg {
230                    backend::Message::RowDescription(body) => match body
231                        .fields()
232                        .map(|f| Ok(Column::new(f.name(), Type::TEXT)))
233                        .collect::<Vec<_>>()
234                    {
235                        Ok(col) => self.col = col.into(),
236                        Err(e) => return Some(Err(Error::from(e))),
237                    },
238                    backend::Message::DataRow(body) => {
239                        return Some(RowSimpleOwned::try_new(self.col.clone(), body, Vec::new()));
240                    }
241                    backend::Message::CommandComplete(_)
242                    | backend::Message::EmptyQueryResponse
243                    | backend::Message::ReadyForQuery(_) => return None,
244                    _ => return Some(Err(Error::unexpected())),
245                },
246                Err(e) => return Some(Err(e)),
247            }
248        }
249    }
250}
251
252/// a stream of table rows where column type looked up and row parsing are bundled together
253pub struct RowStreamGuarded<'a, C> {
254    pub(crate) res: Response,
255    pub(crate) col: Vec<Column>,
256    pub(crate) ranges: Vec<Range<usize>>,
257    pub(crate) cli: &'a C,
258}
259
260impl<'a, C> RowStreamGuarded<'a, C> {
261    pub(crate) fn new(res: Response, cli: &'a C) -> Self {
262        Self {
263            res,
264            col: Vec::new(),
265            ranges: Vec::new(),
266            cli,
267        }
268    }
269}
270
271impl<C> AsyncLendingIterator for RowStreamGuarded<'_, C>
272where
273    C: Prepare + Sync,
274{
275    type Ok<'i>
276        = Row<'i>
277    where
278        Self: 'i;
279    type Err = Error;
280
281    async fn try_next(&mut self) -> Result<Option<Self::Ok<'_>>, Self::Err> {
282        loop {
283            match self.res.recv().await? {
284                backend::Message::RowDescription(body) => {
285                    let mut it = body.fields();
286                    while let Some(field) = it.next()? {
287                        let ty = self.cli._get_type(field.type_oid()).await?;
288                        self.col.push(Column::new(field.name(), ty));
289                    }
290                }
291                backend::Message::DataRow(body) => return Row::try_new(&self.col, body, &mut self.ranges).map(Some),
292                backend::Message::ParseComplete
293                | backend::Message::BindComplete
294                | backend::Message::ParameterDescription(_)
295                | backend::Message::EmptyQueryResponse
296                | backend::Message::CommandComplete(_)
297                | backend::Message::PortalSuspended
298                | backend::Message::NoData => {}
299                backend::Message::ReadyForQuery(_) => return Ok(None),
300                _ => return Err(Error::unexpected()),
301            }
302        }
303    }
304}
305
306pub struct RowStreamGuardedOwned<'a, C> {
307    res: Response,
308    col: Arc<[Column]>,
309    cli: &'a C,
310}
311
312impl<'a, C> From<RowStreamGuarded<'a, C>> for RowStreamGuardedOwned<'a, C> {
313    fn from(stream: RowStreamGuarded<'a, C>) -> Self {
314        Self {
315            res: stream.res,
316            col: stream.col.into(),
317            cli: stream.cli,
318        }
319    }
320}
321
322impl<'a, C> IntoIterator for RowStreamGuarded<'a, C>
323where
324    C: Prepare,
325{
326    type Item = Result<RowOwned, Error>;
327    type IntoIter = RowStreamGuardedOwned<'a, C>;
328
329    fn into_iter(self) -> Self::IntoIter {
330        RowStreamGuardedOwned::from(self)
331    }
332}
333
334impl<C> Iterator for RowStreamGuardedOwned<'_, C>
335where
336    C: Prepare,
337{
338    type Item = Result<RowOwned, Error>;
339
340    fn next(&mut self) -> Option<Self::Item> {
341        loop {
342            match self.res.blocking_recv() {
343                Ok(msg) => match msg {
344                    backend::Message::RowDescription(body) => {
345                        match body
346                            .fields()
347                            .map_err(Error::from)
348                            .map(|f| {
349                                let ty = self.cli._get_type_blocking(f.type_oid())?;
350                                Ok(Column::new(f.name(), ty))
351                            })
352                            .collect::<Vec<_>>()
353                        {
354                            Ok(col) => self.col = col.into(),
355                            Err(e) => return Some(Err(e)),
356                        }
357                    }
358                    backend::Message::DataRow(body) => {
359                        return Some(RowOwned::try_new(self.col.clone(), body, Vec::new()))
360                    }
361                    backend::Message::ParseComplete
362                    | backend::Message::BindComplete
363                    | backend::Message::ParameterDescription(_)
364                    | backend::Message::EmptyQueryResponse
365                    | backend::Message::CommandComplete(_)
366                    | backend::Message::PortalSuspended => {}
367                    backend::Message::NoData | backend::Message::ReadyForQuery(_) => return None,
368                    _ => return Some(Err(Error::unexpected())),
369                },
370                Err(e) => return Some(Err(e)),
371            }
372        }
373    }
374}
375
376pub struct RowAffected {
377    res: Response,
378    rows_affected: u64,
379}
380
381impl Future for RowAffected {
382    type Output = Result<u64, Error>;
383
384    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
385        let this = self.get_mut();
386        ready!(this.res.poll_try_into_ready(&mut this.rows_affected, cx))?;
387        Poll::Ready(Ok(this.rows_affected))
388    }
389}
390
391impl RowAffected {
392    pub(crate) fn wait(self) -> Result<u64, Error> {
393        self.res.try_into_row_affected_blocking()
394    }
395}
396
397impl<C, M> From<GenericRowStream<C, M>> for RowAffected {
398    fn from(stream: GenericRowStream<C, M>) -> Self {
399        Self {
400            res: stream.res,
401            rows_affected: 0,
402        }
403    }
404}
405
406impl<C> From<RowStreamGuarded<'_, C>> for RowAffected {
407    fn from(stream: RowStreamGuarded<'_, C>) -> Self {
408        Self {
409            res: stream.res,
410            rows_affected: 0,
411        }
412    }
413}