1use crate::client::{InnerClient, Responses};
2use crate::codec::FrontendMessage;
3use crate::connection::RequestMessages;
4use crate::types::{BorrowToSql, IsNull};
5use crate::{Error, Portal, Row, Statement};
6use bytes::{Bytes, BytesMut};
7use futures::{ready, Stream};
8use log::{debug, log_enabled, Level};
9use pin_project_lite::pin_project;
10use opengauss_protocol::message::backend::Message;
11use opengauss_protocol::message::frontend;
12use std::fmt;
13use std::marker::PhantomPinned;
14use std::pin::Pin;
15use std::task::{Context, Poll};
16
17struct BorrowToSqlParamsDebug<'a, T>(&'a [T]);
18
19impl<'a, T> fmt::Debug for BorrowToSqlParamsDebug<'a, T>
20where
21 T: BorrowToSql,
22{
23 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
24 f.debug_list()
25 .entries(self.0.iter().map(|x| x.borrow_to_sql()))
26 .finish()
27 }
28}
29
30pub async fn query<P, I>(
31 client: &InnerClient,
32 statement: Statement,
33 params: I,
34) -> Result<RowStream, Error>
35where
36 P: BorrowToSql,
37 I: IntoIterator<Item = P>,
38 I::IntoIter: ExactSizeIterator,
39{
40 let buf = if log_enabled!(Level::Debug) {
41 let params = params.into_iter().collect::<Vec<_>>();
42 debug!(
43 "executing statement {} with parameters: {:?}",
44 statement.name(),
45 BorrowToSqlParamsDebug(params.as_slice()),
46 );
47 encode(client, &statement, params)?
48 } else {
49 encode(client, &statement, params)?
50 };
51 let responses = start(client, buf).await?;
52 Ok(RowStream {
53 statement,
54 responses,
55 _p: PhantomPinned,
56 })
57}
58
59pub async fn query_portal(
60 client: &InnerClient,
61 portal: &Portal,
62 max_rows: i32,
63) -> Result<RowStream, Error> {
64 let buf = client.with_buf(|buf| {
65 frontend::execute(portal.name(), max_rows, buf).map_err(Error::encode)?;
66 frontend::sync(buf);
67 Ok(buf.split().freeze())
68 })?;
69
70 let responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
71
72 Ok(RowStream {
73 statement: portal.statement().clone(),
74 responses,
75 _p: PhantomPinned,
76 })
77}
78
79pub async fn execute<P, I>(
80 client: &InnerClient,
81 statement: Statement,
82 params: I,
83) -> Result<u64, Error>
84where
85 P: BorrowToSql,
86 I: IntoIterator<Item = P>,
87 I::IntoIter: ExactSizeIterator,
88{
89 let buf = if log_enabled!(Level::Debug) {
90 let params = params.into_iter().collect::<Vec<_>>();
91 debug!(
92 "executing statement {} with parameters: {:?}",
93 statement.name(),
94 BorrowToSqlParamsDebug(params.as_slice()),
95 );
96 encode(client, &statement, params)?
97 } else {
98 encode(client, &statement, params)?
99 };
100 let mut responses = start(client, buf).await?;
101
102 loop {
103 match responses.next().await? {
104 Message::DataRow(_) => {}
105 Message::CommandComplete(body) => {
106 let rows = body
107 .tag()
108 .map_err(Error::parse)?
109 .rsplit(' ')
110 .next()
111 .unwrap()
112 .parse()
113 .unwrap_or(0);
114 return Ok(rows);
115 }
116 Message::EmptyQueryResponse => return Ok(0),
117 _ => return Err(Error::unexpected_message()),
118 }
119 }
120}
121
122async fn start(client: &InnerClient, buf: Bytes) -> Result<Responses, Error> {
123 let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
124
125 match responses.next().await? {
126 Message::BindComplete => {}
127 _ => return Err(Error::unexpected_message()),
128 }
129
130 Ok(responses)
131}
132
133pub fn encode<P, I>(client: &InnerClient, statement: &Statement, params: I) -> Result<Bytes, Error>
134where
135 P: BorrowToSql,
136 I: IntoIterator<Item = P>,
137 I::IntoIter: ExactSizeIterator,
138{
139 client.with_buf(|buf| {
140 encode_bind(statement, params, "", buf)?;
141 frontend::execute("", 0, buf).map_err(Error::encode)?;
142 frontend::sync(buf);
143 Ok(buf.split().freeze())
144 })
145}
146
147pub fn encode_bind<P, I>(
148 statement: &Statement,
149 params: I,
150 portal: &str,
151 buf: &mut BytesMut,
152) -> Result<(), Error>
153where
154 P: BorrowToSql,
155 I: IntoIterator<Item = P>,
156 I::IntoIter: ExactSizeIterator,
157{
158 let params = params.into_iter();
159
160 assert!(
161 statement.params().len() == params.len(),
162 "expected {} parameters but got {}",
163 statement.params().len(),
164 params.len()
165 );
166
167 let mut error_idx = 0;
168 let r = frontend::bind(
169 portal,
170 statement.name(),
171 Some(1),
172 params.zip(statement.params()).enumerate(),
173 |(idx, (param, ty)), buf| match param.borrow_to_sql().to_sql_checked(ty, buf) {
174 Ok(IsNull::No) => Ok(opengauss_protocol::IsNull::No),
175 Ok(IsNull::Yes) => Ok(opengauss_protocol::IsNull::Yes),
176 Err(e) => {
177 error_idx = idx;
178 Err(e)
179 }
180 },
181 Some(1),
182 buf,
183 );
184 match r {
185 Ok(()) => Ok(()),
186 Err(frontend::BindError::Conversion(e)) => Err(Error::to_sql(e, error_idx)),
187 Err(frontend::BindError::Serialization(e)) => Err(Error::encode(e)),
188 }
189}
190
pin_project! {
    /// A stream of table rows.
    ///
    /// Returned by `query` and `query_portal`; each item is either a decoded
    /// [`Row`] or an [`Error`].
    pub struct RowStream {
        // Statement the rows belong to; cloned into every `Row` decoded from the stream.
        statement: Statement,
        // Backend responses still to be read for this request.
        responses: Responses,
        // `PhantomPinned` makes `RowStream` `!Unpin`.
        #[pin]
        _p: PhantomPinned,
    }
}
200
201impl Stream for RowStream {
202 type Item = Result<Row, Error>;
203
204 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
205 let this = self.project();
206 match ready!(this.responses.poll_next(cx)?) {
207 Message::DataRow(body) => {
208 Poll::Ready(Some(Ok(Row::new(this.statement.clone(), body)?)))
209 }
210 Message::EmptyQueryResponse
211 | Message::CommandComplete(_)
212 | Message::PortalSuspended => Poll::Ready(None),
213 Message::ErrorResponse(body) => Poll::Ready(Some(Err(Error::db(body)))),
214 _ => Poll::Ready(Some(Err(Error::unexpected_message()))),
215 }
216 }
217}