tokio_opengauss/
query.rs

1use crate::client::{InnerClient, Responses};
2use crate::codec::FrontendMessage;
3use crate::connection::RequestMessages;
4use crate::types::{BorrowToSql, IsNull};
5use crate::{Error, Portal, Row, Statement};
6use bytes::{Bytes, BytesMut};
7use futures::{ready, Stream};
8use log::{debug, log_enabled, Level};
9use pin_project_lite::pin_project;
10use opengauss_protocol::message::backend::Message;
11use opengauss_protocol::message::frontend;
12use std::fmt;
13use std::marker::PhantomPinned;
14use std::pin::Pin;
15use std::task::{Context, Poll};
16
17struct BorrowToSqlParamsDebug<'a, T>(&'a [T]);
18
19impl<'a, T> fmt::Debug for BorrowToSqlParamsDebug<'a, T>
20where
21    T: BorrowToSql,
22{
23    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
24        f.debug_list()
25            .entries(self.0.iter().map(|x| x.borrow_to_sql()))
26            .finish()
27    }
28}
29
30pub async fn query<P, I>(
31    client: &InnerClient,
32    statement: Statement,
33    params: I,
34) -> Result<RowStream, Error>
35where
36    P: BorrowToSql,
37    I: IntoIterator<Item = P>,
38    I::IntoIter: ExactSizeIterator,
39{
40    let buf = if log_enabled!(Level::Debug) {
41        let params = params.into_iter().collect::<Vec<_>>();
42        debug!(
43            "executing statement {} with parameters: {:?}",
44            statement.name(),
45            BorrowToSqlParamsDebug(params.as_slice()),
46        );
47        encode(client, &statement, params)?
48    } else {
49        encode(client, &statement, params)?
50    };
51    let responses = start(client, buf).await?;
52    Ok(RowStream {
53        statement,
54        responses,
55        _p: PhantomPinned,
56    })
57}
58
59pub async fn query_portal(
60    client: &InnerClient,
61    portal: &Portal,
62    max_rows: i32,
63) -> Result<RowStream, Error> {
64    let buf = client.with_buf(|buf| {
65        frontend::execute(portal.name(), max_rows, buf).map_err(Error::encode)?;
66        frontend::sync(buf);
67        Ok(buf.split().freeze())
68    })?;
69
70    let responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
71
72    Ok(RowStream {
73        statement: portal.statement().clone(),
74        responses,
75        _p: PhantomPinned,
76    })
77}
78
79pub async fn execute<P, I>(
80    client: &InnerClient,
81    statement: Statement,
82    params: I,
83) -> Result<u64, Error>
84where
85    P: BorrowToSql,
86    I: IntoIterator<Item = P>,
87    I::IntoIter: ExactSizeIterator,
88{
89    let buf = if log_enabled!(Level::Debug) {
90        let params = params.into_iter().collect::<Vec<_>>();
91        debug!(
92            "executing statement {} with parameters: {:?}",
93            statement.name(),
94            BorrowToSqlParamsDebug(params.as_slice()),
95        );
96        encode(client, &statement, params)?
97    } else {
98        encode(client, &statement, params)?
99    };
100    let mut responses = start(client, buf).await?;
101
102    loop {
103        match responses.next().await? {
104            Message::DataRow(_) => {}
105            Message::CommandComplete(body) => {
106                let rows = body
107                    .tag()
108                    .map_err(Error::parse)?
109                    .rsplit(' ')
110                    .next()
111                    .unwrap()
112                    .parse()
113                    .unwrap_or(0);
114                return Ok(rows);
115            }
116            Message::EmptyQueryResponse => return Ok(0),
117            _ => return Err(Error::unexpected_message()),
118        }
119    }
120}
121
122async fn start(client: &InnerClient, buf: Bytes) -> Result<Responses, Error> {
123    let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
124
125    match responses.next().await? {
126        Message::BindComplete => {}
127        _ => return Err(Error::unexpected_message()),
128    }
129
130    Ok(responses)
131}
132
133pub fn encode<P, I>(client: &InnerClient, statement: &Statement, params: I) -> Result<Bytes, Error>
134where
135    P: BorrowToSql,
136    I: IntoIterator<Item = P>,
137    I::IntoIter: ExactSizeIterator,
138{
139    client.with_buf(|buf| {
140        encode_bind(statement, params, "", buf)?;
141        frontend::execute("", 0, buf).map_err(Error::encode)?;
142        frontend::sync(buf);
143        Ok(buf.split().freeze())
144    })
145}
146
147pub fn encode_bind<P, I>(
148    statement: &Statement,
149    params: I,
150    portal: &str,
151    buf: &mut BytesMut,
152) -> Result<(), Error>
153where
154    P: BorrowToSql,
155    I: IntoIterator<Item = P>,
156    I::IntoIter: ExactSizeIterator,
157{
158    let params = params.into_iter();
159
160    assert!(
161        statement.params().len() == params.len(),
162        "expected {} parameters but got {}",
163        statement.params().len(),
164        params.len()
165    );
166
167    let mut error_idx = 0;
168    let r = frontend::bind(
169        portal,
170        statement.name(),
171        Some(1),
172        params.zip(statement.params()).enumerate(),
173        |(idx, (param, ty)), buf| match param.borrow_to_sql().to_sql_checked(ty, buf) {
174            Ok(IsNull::No) => Ok(opengauss_protocol::IsNull::No),
175            Ok(IsNull::Yes) => Ok(opengauss_protocol::IsNull::Yes),
176            Err(e) => {
177                error_idx = idx;
178                Err(e)
179            }
180        },
181        Some(1),
182        buf,
183    );
184    match r {
185        Ok(()) => Ok(()),
186        Err(frontend::BindError::Conversion(e)) => Err(Error::to_sql(e, error_idx)),
187        Err(frontend::BindError::Serialization(e)) => Err(Error::encode(e)),
188    }
189}
190
191pin_project! {
192    /// A stream of table rows.
193    pub struct RowStream {
194        statement: Statement,
195        responses: Responses,
196        #[pin]
197        _p: PhantomPinned,
198    }
199}
200
201impl Stream for RowStream {
202    type Item = Result<Row, Error>;
203
204    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
205        let this = self.project();
206        match ready!(this.responses.poll_next(cx)?) {
207            Message::DataRow(body) => {
208                Poll::Ready(Some(Ok(Row::new(this.statement.clone(), body)?)))
209            }
210            Message::EmptyQueryResponse
211            | Message::CommandComplete(_)
212            | Message::PortalSuspended => Poll::Ready(None),
213            Message::ErrorResponse(body) => Poll::Ready(Some(Err(Error::db(body)))),
214            _ => Poll::Ready(Some(Err(Error::unexpected_message()))),
215        }
216    }
217}