1use crate::client::{InnerClient, Responses};
2use crate::codec::FrontendMessage;
3use crate::connection::RequestMessages;
4use crate::types::{BorrowToSql, IsNull};
5use crate::{Error, Portal, Row, Statement};
6use bytes::{Bytes, BytesMut};
7use futures_util::{ready, Stream};
8use log::{debug, log_enabled, Level};
9use pin_project_lite::pin_project;
10use postgres_protocol::message::backend::{CommandCompleteBody, Message};
11use postgres_protocol::message::frontend;
12use std::fmt;
13use std::marker::PhantomPinned;
14use std::pin::Pin;
15use std::task::{Context, Poll};
16
17struct BorrowToSqlParamsDebug<'a, T>(&'a [T]);
18
19impl<'a, T> fmt::Debug for BorrowToSqlParamsDebug<'a, T>
20where
21 T: BorrowToSql,
22{
23 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
24 f.debug_list()
25 .entries(self.0.iter().map(|x| x.borrow_to_sql()))
26 .finish()
27 }
28}
29
30pub async fn query<P, I>(
31 client: &InnerClient,
32 statement: Statement,
33 params: I,
34) -> Result<RowStream, Error>
35where
36 P: BorrowToSql,
37 I: IntoIterator<Item = P>,
38 I::IntoIter: ExactSizeIterator,
39{
40 let buf = if log_enabled!(Level::Debug) {
41 let params = params.into_iter().collect::<Vec<_>>();
42 debug!(
43 "executing statement {} with parameters: {:?}",
44 statement.name(),
45 BorrowToSqlParamsDebug(params.as_slice()),
46 );
47 encode(client, &statement, params)?
48 } else {
49 encode(client, &statement, params)?
50 };
51 let responses = start(client, buf).await?;
52 Ok(RowStream {
53 statement,
54 responses,
55 rows_affected: None,
56 _p: PhantomPinned,
57 })
58}
59
60pub async fn query_portal(
61 client: &InnerClient,
62 portal: &Portal,
63 max_rows: i32,
64) -> Result<RowStream, Error> {
65 let buf = client.with_buf(|buf| {
66 frontend::execute(portal.name(), max_rows, buf).map_err(Error::encode)?;
67 frontend::sync(buf);
68 Ok(buf.split().freeze())
69 })?;
70
71 let responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
72
73 Ok(RowStream {
74 statement: portal.statement().clone(),
75 responses,
76 rows_affected: None,
77 _p: PhantomPinned,
78 })
79}
80
81pub fn extract_row_affected(body: &CommandCompleteBody) -> Result<u64, Error> {
83 let rows = body
84 .tag()
85 .map_err(Error::parse)?
86 .rsplit(' ')
87 .next()
88 .unwrap()
89 .parse()
90 .unwrap_or(0);
91 Ok(rows)
92}
93
94pub async fn execute<P, I>(
95 client: &InnerClient,
96 statement: Statement,
97 params: I,
98) -> Result<u64, Error>
99where
100 P: BorrowToSql,
101 I: IntoIterator<Item = P>,
102 I::IntoIter: ExactSizeIterator,
103{
104 let buf = if log_enabled!(Level::Debug) {
105 let params = params.into_iter().collect::<Vec<_>>();
106 debug!(
107 "executing statement {} with parameters: {:?}",
108 statement.name(),
109 BorrowToSqlParamsDebug(params.as_slice()),
110 );
111 encode(client, &statement, params)?
112 } else {
113 encode(client, &statement, params)?
114 };
115 let mut responses = start(client, buf).await?;
116
117 let mut rows = 0;
118 loop {
119 match responses.next().await? {
120 Message::DataRow(_) => {}
121 Message::CommandComplete(body) => {
122 rows = extract_row_affected(&body)?;
123 }
124 Message::EmptyQueryResponse => rows = 0,
125 Message::ReadyForQuery(_) => return Ok(rows),
126 _ => return Err(Error::unexpected_message()),
127 }
128 }
129}
130
131async fn start(client: &InnerClient, buf: Bytes) -> Result<Responses, Error> {
132 let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
133
134 match responses.next().await? {
135 Message::BindComplete => {}
136 _ => return Err(Error::unexpected_message()),
137 }
138
139 Ok(responses)
140}
141
142pub fn encode<P, I>(client: &InnerClient, statement: &Statement, params: I) -> Result<Bytes, Error>
143where
144 P: BorrowToSql,
145 I: IntoIterator<Item = P>,
146 I::IntoIter: ExactSizeIterator,
147{
148 client.with_buf(|buf| {
149 encode_bind(statement, params, "", buf)?;
150 frontend::execute("", 0, buf).map_err(Error::encode)?;
151 frontend::sync(buf);
152 Ok(buf.split().freeze())
153 })
154}
155
156pub fn encode_bind<P, I>(
157 statement: &Statement,
158 params: I,
159 portal: &str,
160 buf: &mut BytesMut,
161) -> Result<(), Error>
162where
163 P: BorrowToSql,
164 I: IntoIterator<Item = P>,
165 I::IntoIter: ExactSizeIterator,
166{
167 let param_types = statement.params();
168 let params = params.into_iter();
169
170 if param_types.len() != params.len() {
171 return Err(Error::parameters(params.len(), param_types.len()));
172 }
173
174 let (param_formats, params): (Vec<_>, Vec<_>) = params
175 .zip(param_types.iter())
176 .map(|(p, ty)| (p.borrow_to_sql().encode_format(ty) as i16, p))
177 .unzip();
178
179 let params = params.into_iter();
180
181 let mut error_idx = 0;
182 let r = frontend::bind(
183 portal,
184 statement.name(),
185 param_formats,
186 params.zip(param_types).enumerate(),
187 |(idx, (param, ty)), buf| match param.borrow_to_sql().to_sql_checked(ty, buf) {
188 Ok(IsNull::No) => Ok(postgres_protocol::IsNull::No),
189 Ok(IsNull::Yes) => Ok(postgres_protocol::IsNull::Yes),
190 Err(e) => {
191 error_idx = idx;
192 Err(e)
193 }
194 },
195 Some(1),
196 buf,
197 );
198 match r {
199 Ok(()) => Ok(()),
200 Err(frontend::BindError::Conversion(e)) => Err(Error::to_sql(e, error_idx)),
201 Err(frontend::BindError::Serialization(e)) => Err(Error::encode(e)),
202 }
203}
204
205pin_project! {
206 pub struct RowStream {
208 statement: Statement,
209 responses: Responses,
210 rows_affected: Option<u64>,
211 #[pin]
212 _p: PhantomPinned,
213 }
214}
215
216impl Stream for RowStream {
217 type Item = Result<Row, Error>;
218
219 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
220 let this = self.project();
221 loop {
222 match ready!(this.responses.poll_next(cx)?) {
223 Message::DataRow(body) => {
224 return Poll::Ready(Some(Ok(Row::new(this.statement.clone(), body)?)))
225 }
226 Message::CommandComplete(body) => {
227 *this.rows_affected = Some(extract_row_affected(&body)?);
228 }
229 Message::EmptyQueryResponse | Message::PortalSuspended => {}
230 Message::ReadyForQuery(_) => return Poll::Ready(None),
231 _ => return Poll::Ready(Some(Err(Error::unexpected_message()))),
232 }
233 }
234 }
235}
236
237impl RowStream {
238 pub fn rows_affected(&self) -> Option<u64> {
242 self.rows_affected
243 }
244}