yb_tokio_postgres/
query.rs

1use crate::client::{InnerClient, Responses};
2use crate::codec::FrontendMessage;
3use crate::connection::RequestMessages;
4use crate::types::{BorrowToSql, IsNull};
5use crate::{Error, Portal, Row, Statement};
6use bytes::{Bytes, BytesMut};
7use futures_util::{ready, Stream};
8use log::{debug, log_enabled, Level};
9use pin_project_lite::pin_project;
10use postgres_protocol::message::backend::{CommandCompleteBody, Message};
11use postgres_protocol::message::frontend;
12use std::fmt;
13use std::marker::PhantomPinned;
14use std::pin::Pin;
15use std::task::{Context, Poll};
16
17struct BorrowToSqlParamsDebug<'a, T>(&'a [T]);
18
19impl<'a, T> fmt::Debug for BorrowToSqlParamsDebug<'a, T>
20where
21    T: BorrowToSql,
22{
23    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
24        f.debug_list()
25            .entries(self.0.iter().map(|x| x.borrow_to_sql()))
26            .finish()
27    }
28}
29
30pub async fn query<P, I>(
31    client: &InnerClient,
32    statement: Statement,
33    params: I,
34) -> Result<RowStream, Error>
35where
36    P: BorrowToSql,
37    I: IntoIterator<Item = P>,
38    I::IntoIter: ExactSizeIterator,
39{
40    let buf = if log_enabled!(Level::Debug) {
41        let params = params.into_iter().collect::<Vec<_>>();
42        debug!(
43            "executing statement {} with parameters: {:?}",
44            statement.name(),
45            BorrowToSqlParamsDebug(params.as_slice()),
46        );
47        encode(client, &statement, params)?
48    } else {
49        encode(client, &statement, params)?
50    };
51    let responses = start(client, buf).await?;
52    Ok(RowStream {
53        statement,
54        responses,
55        rows_affected: None,
56        _p: PhantomPinned,
57    })
58}
59
60pub async fn query_portal(
61    client: &InnerClient,
62    portal: &Portal,
63    max_rows: i32,
64) -> Result<RowStream, Error> {
65    let buf = client.with_buf(|buf| {
66        frontend::execute(portal.name(), max_rows, buf).map_err(Error::encode)?;
67        frontend::sync(buf);
68        Ok(buf.split().freeze())
69    })?;
70
71    let responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
72
73    Ok(RowStream {
74        statement: portal.statement().clone(),
75        responses,
76        rows_affected: None,
77        _p: PhantomPinned,
78    })
79}
80
81/// Extract the number of rows affected from [`CommandCompleteBody`].
82pub fn extract_row_affected(body: &CommandCompleteBody) -> Result<u64, Error> {
83    let rows = body
84        .tag()
85        .map_err(Error::parse)?
86        .rsplit(' ')
87        .next()
88        .unwrap()
89        .parse()
90        .unwrap_or(0);
91    Ok(rows)
92}
93
94pub async fn execute<P, I>(
95    client: &InnerClient,
96    statement: Statement,
97    params: I,
98) -> Result<u64, Error>
99where
100    P: BorrowToSql,
101    I: IntoIterator<Item = P>,
102    I::IntoIter: ExactSizeIterator,
103{
104    let buf = if log_enabled!(Level::Debug) {
105        let params = params.into_iter().collect::<Vec<_>>();
106        debug!(
107            "executing statement {} with parameters: {:?}",
108            statement.name(),
109            BorrowToSqlParamsDebug(params.as_slice()),
110        );
111        encode(client, &statement, params)?
112    } else {
113        encode(client, &statement, params)?
114    };
115    let mut responses = start(client, buf).await?;
116
117    let mut rows = 0;
118    loop {
119        match responses.next().await? {
120            Message::DataRow(_) => {}
121            Message::CommandComplete(body) => {
122                rows = extract_row_affected(&body)?;
123            }
124            Message::EmptyQueryResponse => rows = 0,
125            Message::ReadyForQuery(_) => return Ok(rows),
126            _ => return Err(Error::unexpected_message()),
127        }
128    }
129}
130
131async fn start(client: &InnerClient, buf: Bytes) -> Result<Responses, Error> {
132    let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
133
134    match responses.next().await? {
135        Message::BindComplete => {}
136        _ => return Err(Error::unexpected_message()),
137    }
138
139    Ok(responses)
140}
141
142pub fn encode<P, I>(client: &InnerClient, statement: &Statement, params: I) -> Result<Bytes, Error>
143where
144    P: BorrowToSql,
145    I: IntoIterator<Item = P>,
146    I::IntoIter: ExactSizeIterator,
147{
148    client.with_buf(|buf| {
149        encode_bind(statement, params, "", buf)?;
150        frontend::execute("", 0, buf).map_err(Error::encode)?;
151        frontend::sync(buf);
152        Ok(buf.split().freeze())
153    })
154}
155
156pub fn encode_bind<P, I>(
157    statement: &Statement,
158    params: I,
159    portal: &str,
160    buf: &mut BytesMut,
161) -> Result<(), Error>
162where
163    P: BorrowToSql,
164    I: IntoIterator<Item = P>,
165    I::IntoIter: ExactSizeIterator,
166{
167    let param_types = statement.params();
168    let params = params.into_iter();
169
170    if param_types.len() != params.len() {
171        return Err(Error::parameters(params.len(), param_types.len()));
172    }
173
174    let (param_formats, params): (Vec<_>, Vec<_>) = params
175        .zip(param_types.iter())
176        .map(|(p, ty)| (p.borrow_to_sql().encode_format(ty) as i16, p))
177        .unzip();
178
179    let params = params.into_iter();
180
181    let mut error_idx = 0;
182    let r = frontend::bind(
183        portal,
184        statement.name(),
185        param_formats,
186        params.zip(param_types).enumerate(),
187        |(idx, (param, ty)), buf| match param.borrow_to_sql().to_sql_checked(ty, buf) {
188            Ok(IsNull::No) => Ok(postgres_protocol::IsNull::No),
189            Ok(IsNull::Yes) => Ok(postgres_protocol::IsNull::Yes),
190            Err(e) => {
191                error_idx = idx;
192                Err(e)
193            }
194        },
195        Some(1),
196        buf,
197    );
198    match r {
199        Ok(()) => Ok(()),
200        Err(frontend::BindError::Conversion(e)) => Err(Error::to_sql(e, error_idx)),
201        Err(frontend::BindError::Serialization(e)) => Err(Error::encode(e)),
202    }
203}
204
205pin_project! {
206    /// A stream of table rows.
207    pub struct RowStream {
208        statement: Statement,
209        responses: Responses,
210        rows_affected: Option<u64>,
211        #[pin]
212        _p: PhantomPinned,
213    }
214}
215
216impl Stream for RowStream {
217    type Item = Result<Row, Error>;
218
219    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
220        let this = self.project();
221        loop {
222            match ready!(this.responses.poll_next(cx)?) {
223                Message::DataRow(body) => {
224                    return Poll::Ready(Some(Ok(Row::new(this.statement.clone(), body)?)))
225                }
226                Message::CommandComplete(body) => {
227                    *this.rows_affected = Some(extract_row_affected(&body)?);
228                }
229                Message::EmptyQueryResponse | Message::PortalSuspended => {}
230                Message::ReadyForQuery(_) => return Poll::Ready(None),
231                _ => return Poll::Ready(Some(Err(Error::unexpected_message()))),
232            }
233        }
234    }
235}
236
237impl RowStream {
238    /// Returns the number of rows affected by the query.
239    ///
240    /// This function will return `None` until the stream has been exhausted.
241    pub fn rows_affected(&self) -> Option<u64> {
242        self.rows_affected
243    }
244}