sqlx_postgres/connection/
executor.rs1use crate::describe::Describe;
2use crate::error::Error;
3use crate::executor::{Execute, Executor};
4use crate::io::{PortalId, StatementId};
5use crate::logger::QueryLogger;
6use crate::message::{
7 self, BackendMessageFormat, Bind, Close, CommandComplete, DataRow, ParameterDescription, Parse,
8 ParseComplete, Query, RowDescription,
9};
10use crate::statement::PgStatementMetadata;
11use crate::{
12 statement::PgStatement, PgArguments, PgConnection, PgQueryResult, PgRow, PgTypeInfo,
13 PgValueFormat, Postgres,
14};
15use futures_core::future::BoxFuture;
16use futures_core::stream::BoxStream;
17use futures_core::Stream;
18use futures_util::TryStreamExt;
19use sqlx_core::arguments::Arguments;
20use sqlx_core::Either;
21use std::{borrow::Cow, pin::pin, sync::Arc};
22
23async fn prepare(
24 conn: &mut PgConnection,
25 sql: &str,
26 parameters: &[PgTypeInfo],
27 metadata: Option<Arc<PgStatementMetadata>>,
28 persistent: bool,
29) -> Result<(StatementId, Arc<PgStatementMetadata>), Error> {
30 let id = if persistent {
31 let id = conn.inner.next_statement_id;
32 conn.inner.next_statement_id = id.next();
33 id
34 } else {
35 StatementId::UNNAMED
36 };
37
38 let mut param_types = Vec::with_capacity(parameters.len());
43
44 for ty in parameters {
45 param_types.push(conn.resolve_type_id(&ty.0).await?);
46 }
47
48 conn.wait_until_ready().await?;
50
51 conn.inner.stream.write_msg(Parse {
53 param_types: ¶m_types,
54 query: sql,
55 statement: id,
56 })?;
57
58 if metadata.is_none() {
59 conn.inner
61 .stream
62 .write_msg(message::Describe::Statement(id))?;
63 }
64
65 conn.write_sync();
67 conn.inner.stream.flush().await?;
68
69 conn.inner.stream.recv_expect::<ParseComplete>().await?;
71
72 let metadata = if let Some(metadata) = metadata {
73 conn.recv_ready_for_query().await?;
75
76 metadata
78 } else {
79 let parameters = recv_desc_params(conn).await?;
80
81 let rows = recv_desc_rows(conn).await?;
82
83 conn.recv_ready_for_query().await?;
85
86 let parameters = conn.handle_parameter_description(parameters).await?;
87
88 let (columns, column_names) = conn.handle_row_description(rows, true).await?;
89
90 conn.wait_until_ready().await?;
93
94 Arc::new(PgStatementMetadata {
95 parameters,
96 columns,
97 column_names: Arc::new(column_names),
98 })
99 };
100
101 Ok((id, metadata))
102}
103
104async fn recv_desc_params(conn: &mut PgConnection) -> Result<ParameterDescription, Error> {
105 conn.inner.stream.recv_expect().await
106}
107
108async fn recv_desc_rows(conn: &mut PgConnection) -> Result<Option<RowDescription>, Error> {
109 let rows: Option<RowDescription> = match conn.inner.stream.recv().await? {
110 message if message.format == BackendMessageFormat::RowDescription => {
112 Some(message.decode()?)
113 }
114
115 message if message.format == BackendMessageFormat::NoData => None,
117
118 message => {
119 return Err(err_protocol!(
120 "expecting RowDescription or NoData but received {:?}",
121 message.format
122 ));
123 }
124 };
125
126 Ok(rows)
127}
128
129impl PgConnection {
130 pub(super) async fn wait_for_close_complete(&mut self, mut count: usize) -> Result<(), Error> {
132 while count > 0 {
134 match self.inner.stream.recv().await? {
135 message if message.format == BackendMessageFormat::PortalSuspended => {
136 }
139
140 message if message.format == BackendMessageFormat::CloseComplete => {
141 count -= 1;
143 }
144
145 message => {
146 return Err(err_protocol!(
147 "expecting PortalSuspended or CloseComplete but received {:?}",
148 message.format
149 ));
150 }
151 }
152 }
153
154 Ok(())
155 }
156
157 #[inline(always)]
158 pub(crate) fn write_sync(&mut self) {
159 self.inner
160 .stream
161 .write_msg(message::Sync)
162 .expect("BUG: Sync should not be too big for protocol");
163
164 self.inner.pending_ready_for_query_count += 1;
166 }
167
168 async fn get_or_prepare<'a>(
169 &mut self,
170 sql: &str,
171 parameters: &[PgTypeInfo],
172 persistent: bool,
173 metadata: Option<Arc<PgStatementMetadata>>,
176 ) -> Result<(StatementId, Arc<PgStatementMetadata>), Error> {
177 if let Some(statement) = self.inner.cache_statement.get_mut(sql) {
178 return Ok((*statement).clone());
179 }
180
181 let statement = prepare(self, sql, parameters, metadata, persistent).await?;
182
183 if persistent && self.inner.cache_statement.is_enabled() {
184 if let Some((id, _)) = self.inner.cache_statement.insert(sql, statement.clone()) {
185 self.inner.stream.write_msg(Close::Statement(id))?;
186 self.write_sync();
187
188 self.inner.stream.flush().await?;
189
190 self.wait_for_close_complete(1).await?;
191 self.recv_ready_for_query().await?;
192 }
193 }
194
195 Ok(statement)
196 }
197
198 pub(crate) async fn run<'e, 'c: 'e, 'q: 'e>(
199 &'c mut self,
200 query: &'q str,
201 arguments: Option<PgArguments>,
202 persistent: bool,
203 metadata_opt: Option<Arc<PgStatementMetadata>>,
204 ) -> Result<impl Stream<Item = Result<Either<PgQueryResult, PgRow>, Error>> + 'e, Error> {
205 let mut logger = QueryLogger::new(query, self.inner.log_settings.clone());
206
207 self.wait_until_ready().await?;
209
210 let mut metadata: Arc<PgStatementMetadata>;
211
212 let format = if let Some(mut arguments) = arguments {
213 let num_params = u16::try_from(arguments.len()).map_err(|_| {
220 err_protocol!(
221 "PgConnection::run(): too many arguments for query: {}",
222 arguments.len()
223 )
224 })?;
225
226 let (statement, metadata_) = self
229 .get_or_prepare(query, &arguments.types, persistent, metadata_opt)
230 .await?;
231
232 metadata = metadata_;
233
234 arguments.apply_patches(self, &metadata.parameters).await?;
236
237 self.wait_until_ready().await?;
239
240 self.inner.stream.write_msg(Bind {
242 portal: PortalId::UNNAMED,
243 statement,
244 formats: &[PgValueFormat::Binary],
245 num_params,
246 params: &arguments.buffer,
247 result_formats: &[PgValueFormat::Binary],
248 })?;
249
250 self.inner.stream.write_msg(message::Execute {
253 portal: PortalId::UNNAMED,
254 limit: 0,
257 })?;
258 self.inner
268 .stream
269 .write_msg(Close::Portal(PortalId::UNNAMED))?;
270
271 self.write_sync();
277
278 PgValueFormat::Binary
280 } else {
281 self.inner.stream.write_msg(Query(query))?;
283 self.inner.pending_ready_for_query_count += 1;
284
285 metadata = Arc::new(PgStatementMetadata::default());
287
288 PgValueFormat::Text
290 };
291
292 self.inner.stream.flush().await?;
293
294 Ok(try_stream! {
295 loop {
296 let message = self.inner.stream.recv().await?;
297
298 match message.format {
299 BackendMessageFormat::BindComplete
300 | BackendMessageFormat::ParseComplete
301 | BackendMessageFormat::ParameterDescription
302 | BackendMessageFormat::NoData
303 | BackendMessageFormat::CloseComplete
305 => {
306 }
308
309 BackendMessageFormat::CommandComplete => {
314 let cc: CommandComplete = message.decode()?;
316
317 let rows_affected = cc.rows_affected();
318 logger.increase_rows_affected(rows_affected);
319 r#yield!(Either::Left(PgQueryResult {
320 rows_affected,
321 }));
322 }
323
324 BackendMessageFormat::EmptyQueryResponse => {
325 }
327
328 BackendMessageFormat::PortalSuspended => {}
332
333 BackendMessageFormat::RowDescription => {
334 let (columns, column_names) = self
336 .handle_row_description(Some(message.decode()?), false)
337 .await?;
338
339 metadata = Arc::new(PgStatementMetadata {
340 column_names: Arc::new(column_names),
341 columns,
342 parameters: Vec::default(),
343 });
344 }
345
346 BackendMessageFormat::DataRow => {
347 logger.increment_rows_returned();
348
349 let data: DataRow = message.decode()?;
351 let row = PgRow {
352 data,
353 format,
354 metadata: Arc::clone(&metadata),
355 };
356
357 r#yield!(Either::Right(row));
358 }
359
360 BackendMessageFormat::ReadyForQuery => {
361 self.handle_ready_for_query(message)?;
363 break;
364 }
365
366 _ => {
367 return Err(err_protocol!(
368 "execute: unexpected message: {:?}",
369 message.format
370 ));
371 }
372 }
373 }
374
375 Ok(())
376 })
377 }
378}
379
380impl<'c> Executor<'c> for &'c mut PgConnection {
381 type Database = Postgres;
382
383 fn fetch_many<'e, 'q, E>(
384 self,
385 mut query: E,
386 ) -> BoxStream<'e, Result<Either<PgQueryResult, PgRow>, Error>>
387 where
388 'c: 'e,
389 E: Execute<'q, Self::Database>,
390 'q: 'e,
391 E: 'q,
392 {
393 let sql = query.sql();
394 #[allow(clippy::map_clone)]
396 let metadata = query.statement().map(|s| Arc::clone(&s.metadata));
397 let arguments = query.take_arguments().map_err(Error::Encode);
398 let persistent = query.persistent();
399
400 Box::pin(try_stream! {
401 let arguments = arguments?;
402 let mut s = pin!(self.run(sql, arguments, persistent, metadata).await?);
403
404 while let Some(v) = s.try_next().await? {
405 r#yield!(v);
406 }
407
408 Ok(())
409 })
410 }
411
412 fn fetch_optional<'e, 'q, E>(self, mut query: E) -> BoxFuture<'e, Result<Option<PgRow>, Error>>
413 where
414 'c: 'e,
415 E: Execute<'q, Self::Database>,
416 'q: 'e,
417 E: 'q,
418 {
419 let sql = query.sql();
420 #[allow(clippy::map_clone)]
422 let metadata = query.statement().map(|s| Arc::clone(&s.metadata));
423 let arguments = query.take_arguments().map_err(Error::Encode);
424 let persistent = query.persistent();
425
426 Box::pin(async move {
427 let arguments = arguments?;
428 let mut s = pin!(self.run(sql, arguments, persistent, metadata).await?);
429
430 let mut ret = None;
435 while let Some(result) = s.try_next().await? {
436 match result {
437 Either::Right(r) if ret.is_none() => ret = Some(r),
438 _ => {}
439 }
440 }
441 Ok(ret)
442 })
443 }
444
445 fn prepare_with<'e, 'q: 'e>(
446 self,
447 sql: &'q str,
448 parameters: &'e [PgTypeInfo],
449 ) -> BoxFuture<'e, Result<PgStatement<'q>, Error>>
450 where
451 'c: 'e,
452 {
453 Box::pin(async move {
454 self.wait_until_ready().await?;
455
456 let (_, metadata) = self.get_or_prepare(sql, parameters, true, None).await?;
457
458 Ok(PgStatement {
459 sql: Cow::Borrowed(sql),
460 metadata,
461 })
462 })
463 }
464
465 fn describe<'e, 'q: 'e>(
466 self,
467 sql: &'q str,
468 ) -> BoxFuture<'e, Result<Describe<Self::Database>, Error>>
469 where
470 'c: 'e,
471 {
472 Box::pin(async move {
473 self.wait_until_ready().await?;
474
475 let (stmt_id, metadata) = self.get_or_prepare(sql, &[], true, None).await?;
476
477 let nullable = self.get_nullable_for_columns(stmt_id, &metadata).await?;
478
479 Ok(Describe {
480 columns: metadata.columns.clone(),
481 nullable,
482 parameters: Some(Either::Left(metadata.parameters.clone())),
483 })
484 })
485 }
486}