1use crate::client::{InnerClient, Responses};
2use crate::codec::FrontendMessage;
3use crate::connection::RequestMessages;
4use crate::prepare::get_type;
5use crate::types::{BorrowToSql, IsNull};
6use crate::{Column, Error, Portal, Row, Statement};
7use bytes::{Bytes, BytesMut};
8use fallible_iterator::FallibleIterator;
9use futures_util::Stream;
10use log::{Level, debug, log_enabled};
11use pin_project_lite::pin_project;
12use postgres_protocol::message::backend::{CommandCompleteBody, Message};
13use postgres_protocol::message::frontend;
14use postgres_types::Type;
15use std::fmt;
16use std::pin::Pin;
17use std::sync::Arc;
18use std::task::{Context, Poll, ready};
19
20struct BorrowToSqlParamsDebug<'a, T>(&'a [T]);
21
22impl<T> fmt::Debug for BorrowToSqlParamsDebug<'_, T>
23where
24 T: BorrowToSql,
25{
26 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
27 f.debug_list()
28 .entries(self.0.iter().map(|x| x.borrow_to_sql()))
29 .finish()
30 }
31}
32
33pub async fn query<P, I>(
34 client: &InnerClient,
35 statement: Statement,
36 params: I,
37) -> Result<RowStream, Error>
38where
39 P: BorrowToSql,
40 I: IntoIterator<Item = P>,
41 I::IntoIter: ExactSizeIterator,
42{
43 let buf = if log_enabled!(Level::Debug) {
44 let params = params.into_iter().collect::<Vec<_>>();
45 debug!(
46 "executing statement {} with parameters: {:?}",
47 statement.name(),
48 BorrowToSqlParamsDebug(params.as_slice()),
49 );
50 encode(client, &statement, params)?
51 } else {
52 encode(client, &statement, params)?
53 };
54 let responses = start(client, buf).await?;
55 Ok(RowStream {
56 statement,
57 responses,
58 rows_affected: None,
59 })
60}
61
62pub async fn query_typed<P, I>(
63 client: &Arc<InnerClient>,
64 query: &str,
65 params: I,
66) -> Result<RowStream, Error>
67where
68 P: BorrowToSql,
69 I: IntoIterator<Item = (P, Type)>,
70{
71 let buf = {
72 let params = params.into_iter().collect::<Vec<_>>();
73 let param_oids = params.iter().map(|(_, t)| t.oid()).collect::<Vec<_>>();
74
75 client.with_buf(|buf| {
76 frontend::parse("", query, param_oids, buf).map_err(Error::parse)?;
77 encode_bind_raw("", params, "", buf)?;
78 frontend::describe(b'S', "", buf).map_err(Error::encode)?;
79 frontend::execute("", 0, buf).map_err(Error::encode)?;
80 frontend::sync(buf);
81
82 Ok(buf.split().freeze())
83 })?
84 };
85
86 let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
87
88 loop {
89 match responses.next().await? {
90 Message::ParseComplete | Message::BindComplete | Message::ParameterDescription(_) => {}
91 Message::NoData => {
92 return Ok(RowStream {
93 statement: Statement::unnamed(vec![], vec![]),
94 responses,
95 rows_affected: None,
96 });
97 }
98 Message::RowDescription(row_description) => {
99 let mut columns: Vec<Column> = vec![];
100 let mut it = row_description.fields();
101 while let Some(field) = it.next().map_err(Error::parse)? {
102 let type_ = get_type(client, field.type_oid()).await?;
103 let column = Column {
104 name: field.name().to_string(),
105 table_oid: Some(field.table_oid()).filter(|n| *n != 0),
106 column_id: Some(field.column_id()).filter(|n| *n != 0),
107 type_modifier: field.type_modifier(),
108 r#type: type_,
109 };
110 columns.push(column);
111 }
112 return Ok(RowStream {
113 statement: Statement::unnamed(vec![], columns),
114 responses,
115 rows_affected: None,
116 });
117 }
118 _ => return Err(Error::unexpected_message()),
119 }
120 }
121}
122
123pub async fn execute_typed<P, I>(
124 client: &Arc<InnerClient>,
125 query: &str,
126 params: I,
127) -> Result<u64, Error>
128where
129 P: BorrowToSql,
130 I: IntoIterator<Item = (P, Type)>,
131{
132 let buf = {
133 let params = params.into_iter().collect::<Vec<_>>();
134 let param_oids = params.iter().map(|(_, t)| t.oid()).collect::<Vec<_>>();
135
136 client.with_buf(|buf| {
137 frontend::parse("", query, param_oids, buf).map_err(Error::parse)?;
138 encode_bind_raw("", params, "", buf)?;
139 frontend::describe(b'S', "", buf).map_err(Error::encode)?;
140 frontend::execute("", 0, buf).map_err(Error::encode)?;
141 frontend::sync(buf);
142
143 Ok(buf.split().freeze())
144 })?
145 };
146
147 let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
148
149 let mut rows = 0;
150
151 loop {
152 match responses.next().await? {
153 Message::ParseComplete
154 | Message::BindComplete
155 | Message::ParameterDescription(_)
156 | Message::RowDescription(_) => {}
157 Message::NoData => {
158 rows = 0;
159 }
160
161 Message::DataRow(_) => {}
162 Message::CommandComplete(body) => {
163 rows = extract_row_affected(&body)?;
164 }
165
166 Message::EmptyQueryResponse => rows = 0,
167 Message::ReadyForQuery(_) => return Ok(rows),
168 _ => {
169 return Err(Error::unexpected_message());
170 }
171 }
172 }
173}
174
175pub async fn query_portal(
176 client: &InnerClient,
177 portal: &Portal,
178 max_rows: i32,
179) -> Result<RowStream, Error> {
180 let buf = client.with_buf(|buf| {
181 frontend::execute(portal.name(), max_rows, buf).map_err(Error::encode)?;
182 frontend::sync(buf);
183 Ok(buf.split().freeze())
184 })?;
185
186 let responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
187
188 Ok(RowStream {
189 statement: portal.statement().clone(),
190 responses,
191 rows_affected: None,
192 })
193}
194
195pub fn extract_row_affected(body: &CommandCompleteBody) -> Result<u64, Error> {
197 let rows = body
198 .tag()
199 .map_err(Error::parse)?
200 .rsplit(' ')
201 .next()
202 .unwrap()
203 .parse()
204 .unwrap_or(0);
205 Ok(rows)
206}
207
208pub async fn execute<P, I>(
209 client: &InnerClient,
210 statement: Statement,
211 params: I,
212) -> Result<u64, Error>
213where
214 P: BorrowToSql,
215 I: IntoIterator<Item = P>,
216 I::IntoIter: ExactSizeIterator,
217{
218 let buf = if log_enabled!(Level::Debug) {
219 let params = params.into_iter().collect::<Vec<_>>();
220 debug!(
221 "executing statement {} with parameters: {:?}",
222 statement.name(),
223 BorrowToSqlParamsDebug(params.as_slice()),
224 );
225 encode(client, &statement, params)?
226 } else {
227 encode(client, &statement, params)?
228 };
229 let mut responses = start(client, buf).await?;
230
231 let mut rows = 0;
232 loop {
233 match responses.next().await? {
234 Message::DataRow(_) => {}
235 Message::CommandComplete(body) => {
236 rows = extract_row_affected(&body)?;
237 }
238 Message::EmptyQueryResponse => rows = 0,
239 Message::ReadyForQuery(_) => return Ok(rows),
240 _ => return Err(Error::unexpected_message()),
241 }
242 }
243}
244
245async fn start(client: &InnerClient, buf: Bytes) -> Result<Responses, Error> {
246 let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
247
248 match responses.next().await? {
249 Message::BindComplete => {}
250 _ => return Err(Error::unexpected_message()),
251 }
252
253 Ok(responses)
254}
255
256pub fn encode<P, I>(client: &InnerClient, statement: &Statement, params: I) -> Result<Bytes, Error>
257where
258 P: BorrowToSql,
259 I: IntoIterator<Item = P>,
260 I::IntoIter: ExactSizeIterator,
261{
262 client.with_buf(|buf| {
263 encode_bind(statement, params, "", buf)?;
264 frontend::execute("", 0, buf).map_err(Error::encode)?;
265 frontend::sync(buf);
266 Ok(buf.split().freeze())
267 })
268}
269
270pub fn encode_bind<P, I>(
271 statement: &Statement,
272 params: I,
273 portal: &str,
274 buf: &mut BytesMut,
275) -> Result<(), Error>
276where
277 P: BorrowToSql,
278 I: IntoIterator<Item = P>,
279 I::IntoIter: ExactSizeIterator,
280{
281 let params = params.into_iter();
282 if params.len() != statement.params().len() {
283 return Err(Error::parameters(params.len(), statement.params().len()));
284 }
285
286 encode_bind_raw(
287 statement.name(),
288 params.zip(statement.params().iter().cloned()),
289 portal,
290 buf,
291 )
292}
293
294fn encode_bind_raw<P, I>(
295 statement_name: &str,
296 params: I,
297 portal: &str,
298 buf: &mut BytesMut,
299) -> Result<(), Error>
300where
301 P: BorrowToSql,
302 I: IntoIterator<Item = (P, Type)>,
303 I::IntoIter: ExactSizeIterator,
304{
305 let (param_formats, params): (Vec<_>, Vec<_>) = params
306 .into_iter()
307 .map(|(p, ty)| (p.borrow_to_sql().encode_format(&ty) as i16, (p, ty)))
308 .unzip();
309
310 let mut error_idx = 0;
311 let r = frontend::bind(
312 portal,
313 statement_name,
314 param_formats,
315 params.into_iter().enumerate(),
316 |(idx, (param, ty)), buf| match param.borrow_to_sql().to_sql_checked(&ty, buf) {
317 Ok(IsNull::No) => Ok(postgres_protocol::IsNull::No),
318 Ok(IsNull::Yes) => Ok(postgres_protocol::IsNull::Yes),
319 Err(e) => {
320 error_idx = idx;
321 Err(e)
322 }
323 },
324 Some(1),
325 buf,
326 );
327 match r {
328 Ok(()) => Ok(()),
329 Err(frontend::BindError::Conversion(e)) => Err(Error::to_sql(e, error_idx)),
330 Err(frontend::BindError::Serialization(e)) => Err(Error::encode(e)),
331 }
332}
333
334pin_project! {
335 #[project(!Unpin)]
337 pub struct RowStream {
338 statement: Statement,
339 responses: Responses,
340 rows_affected: Option<u64>,
341 }
342}
343
344impl Stream for RowStream {
345 type Item = Result<Row, Error>;
346
347 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
348 let this = self.project();
349 loop {
350 match ready!(this.responses.poll_next(cx)?) {
351 Message::DataRow(body) => {
352 return Poll::Ready(Some(Ok(Row::new(this.statement.clone(), body)?)));
353 }
354 Message::CommandComplete(body) => {
355 *this.rows_affected = Some(extract_row_affected(&body)?);
356 }
357 Message::EmptyQueryResponse | Message::PortalSuspended => {}
358 Message::ReadyForQuery(_) => return Poll::Ready(None),
359 _ => return Poll::Ready(Some(Err(Error::unexpected_message()))),
360 }
361 }
362 }
363}
364
365impl RowStream {
366 pub fn rows_affected(&self) -> Option<u64> {
370 self.rows_affected
371 }
372}
373
374pub async fn sync(client: &InnerClient) -> Result<(), Error> {
375 let buf = Bytes::from_static(b"S\0\0\0\x04");
376 let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
377
378 match responses.next().await? {
379 Message::ReadyForQuery(_) => Ok(()),
380 _ => Err(Error::unexpected_message()),
381 }
382}