sqlx_postgres/connection/
executor.rs1use crate::describe::Describe;
2use crate::error::Error;
3use crate::executor::{Execute, Executor};
4use crate::io::{PortalId, StatementId};
5use crate::logger::QueryLogger;
6use crate::message::{
7 self, BackendMessageFormat, Bind, Close, CommandComplete, DataRow, ParameterDescription, Parse,
8 ParseComplete, Query, RowDescription,
9};
10use crate::statement::PgStatementMetadata;
11use crate::{
12 statement::PgStatement, PgArguments, PgConnection, PgQueryResult, PgRow, PgTypeInfo,
13 PgValueFormat, Postgres,
14};
15use futures_core::future::BoxFuture;
16use futures_core::stream::BoxStream;
17use futures_core::Stream;
18use futures_util::TryStreamExt;
19use sqlx_core::arguments::Arguments;
20use sqlx_core::sql_str::SqlStr;
21use sqlx_core::Either;
22use std::{pin::pin, sync::Arc};
23
24async fn prepare(
25 conn: &mut PgConnection,
26 sql: &str,
27 parameters: &[PgTypeInfo],
28 metadata: Option<Arc<PgStatementMetadata>>,
29 persistent: bool,
30 fetch_column_origin: bool,
31) -> Result<(StatementId, Arc<PgStatementMetadata>), Error> {
32 let id = if persistent {
33 let id = conn.inner.next_statement_id;
34 conn.inner.next_statement_id = id.next();
35 id
36 } else {
37 StatementId::UNNAMED
38 };
39
40 let mut param_types = Vec::with_capacity(parameters.len());
45
46 for ty in parameters {
47 param_types.push(conn.resolve_type_id(&ty.0).await?);
48 }
49
50 conn.wait_until_ready().await?;
52
53 conn.inner.stream.write_msg(Parse {
55 param_types: ¶m_types,
56 query: sql,
57 statement: id,
58 })?;
59
60 if metadata.is_none() {
61 conn.inner
63 .stream
64 .write_msg(message::Describe::Statement(id))?;
65 }
66
67 conn.write_sync();
69 conn.inner.stream.flush().await?;
70
71 conn.inner.stream.recv_expect::<ParseComplete>().await?;
73
74 let metadata = if let Some(metadata) = metadata {
75 conn.recv_ready_for_query().await?;
77
78 metadata
80 } else {
81 let parameters = recv_desc_params(conn).await?;
82
83 let rows = recv_desc_rows(conn).await?;
84
85 conn.recv_ready_for_query().await?;
87
88 let parameters = conn.handle_parameter_description(parameters).await?;
89
90 let (columns, column_names) = conn
91 .handle_row_description(rows, true, fetch_column_origin)
92 .await?;
93
94 conn.wait_until_ready().await?;
97
98 Arc::new(PgStatementMetadata {
99 parameters,
100 columns,
101 column_names: Arc::new(column_names),
102 })
103 };
104
105 Ok((id, metadata))
106}
107
108async fn recv_desc_params(conn: &mut PgConnection) -> Result<ParameterDescription, Error> {
109 conn.inner.stream.recv_expect().await
110}
111
112async fn recv_desc_rows(conn: &mut PgConnection) -> Result<Option<RowDescription>, Error> {
113 let rows: Option<RowDescription> = match conn.inner.stream.recv().await? {
114 message if message.format == BackendMessageFormat::RowDescription => {
116 Some(message.decode()?)
117 }
118
119 message if message.format == BackendMessageFormat::NoData => None,
121
122 message => {
123 return Err(err_protocol!(
124 "expecting RowDescription or NoData but received {:?}",
125 message.format
126 ));
127 }
128 };
129
130 Ok(rows)
131}
132
impl PgConnection {
    /// Consumes backend messages until `count` `CloseComplete` messages have arrived.
    ///
    /// `PortalSuspended` messages seen in the meantime are skipped. Any other message is
    /// reported as a protocol error.
    pub(super) async fn wait_for_close_complete(&mut self, mut count: usize) -> Result<(), Error> {
        while count > 0 {
            match self.inner.stream.recv().await? {
                message if message.format == BackendMessageFormat::PortalSuspended => {
                    // Tolerated and skipped. NOTE(review): presumably left over from a portal
                    // that was not run to completion — confirm.
                }

                message if message.format == BackendMessageFormat::CloseComplete => {
                    // One of the requested `Close` operations has been acknowledged.
                    count -= 1;
                }

                message => {
                    return Err(err_protocol!(
                        "expecting PortalSuspended or CloseComplete but received {:?}",
                        message.format
                    ));
                }
            }
        }

        Ok(())
    }

    /// Writes a `Sync` message and records that one more `ReadyForQuery` is now owed by the
    /// server, by incrementing `pending_ready_for_query_count`.
    ///
    /// The stream is not flushed here; callers in this file flush it afterwards.
    ///
    /// # Panics
    ///
    /// Panics if writing the `Sync` message returns an error, which is treated as a bug.
    #[inline(always)]
    pub(crate) fn write_sync(&mut self) {
        self.inner
            .stream
            .write_msg(message::Sync)
            .expect("BUG: Sync should not be too big for protocol");

        self.inner.pending_ready_for_query_count += 1;
    }

    /// Returns the statement id and metadata for `sql`, preparing it on the server if needed.
    ///
    /// A hit in the connection's statement cache is returned (cloned) immediately, without
    /// any server round-trip. In that case `parameters`, `persistent` and `metadata` are
    /// ignored.
    ///
    /// On a miss the statement is prepared with `prepare`. It is cached only when
    /// `persistent` is set and the cache is enabled.
    pub(crate) async fn get_or_prepare(
        &mut self,
        sql: &str,
        parameters: &[PgTypeInfo],
        persistent: bool,
        // Metadata the caller already has (e.g. from a previously prepared statement);
        // when `Some`, `prepare` does not ask the server to describe the statement.
        metadata: Option<Arc<PgStatementMetadata>>,
        fetch_column_origin: bool,
    ) -> Result<(StatementId, Arc<PgStatementMetadata>), Error> {
        if let Some(statement) = self.inner.cache_statement.get_mut(sql) {
            return Ok((*statement).clone());
        }

        let statement = prepare(
            self,
            sql,
            parameters,
            metadata,
            persistent,
            fetch_column_origin,
        )
        .await?;

        if persistent && self.inner.cache_statement.is_enabled() {
            // `insert` may hand back a displaced entry (presumably the least-recently-used one
            // when the cache is full). That statement is closed on the server so it does not
            // linger there.
            if let Some((id, _)) = self.inner.cache_statement.insert(sql, statement.clone()) {
                self.inner.stream.write_msg(Close::Statement(id))?;
                self.write_sync();

                self.inner.stream.flush().await?;

                // One `CloseComplete` for the `Close`, then the `ReadyForQuery` for the `Sync`.
                self.wait_for_close_complete(1).await?;
                self.recv_ready_for_query().await?;
            }
        }

        Ok(statement)
    }

    /// Sends `query` to the server and returns a stream of its results.
    ///
    /// With `arguments`, the extended protocol is used:
    /// 1. the statement is taken from the cache or prepared via `get_or_prepare`;
    /// 2. argument patches are applied;
    /// 3. `Bind`, `Execute` (no row limit), `Close` of the unnamed portal and `Sync` are
    ///    written.
    ///
    /// Rows are decoded in binary format. Without `arguments`, a simple `Query` message is
    /// sent and rows are decoded in text format.
    ///
    /// All messages are flushed before this function returns. The returned stream then:
    /// - yields `Either::Left(PgQueryResult)` for each `CommandComplete`;
    /// - yields `Either::Right(PgRow)` for each `DataRow`;
    /// - ends at `ReadyForQuery`.
    ///
    /// # Errors
    ///
    /// Fails before any message for this query is written if there are more than
    /// `u16::MAX` arguments. Preparing, writing and reading errors are propagated. An
    /// unexpected backend message ends the stream with a protocol error.
    pub(crate) async fn run<'e, 'c: 'e, 'q: 'e>(
        &'c mut self,
        query: SqlStr,
        arguments: Option<PgArguments>,
        persistent: bool,
        metadata_opt: Option<Arc<PgStatementMetadata>>,
    ) -> Result<impl Stream<Item = Result<Either<PgQueryResult, PgRow>, Error>> + 'e, Error> {
        // The logger takes ownership of the SQL text; the code below borrows it back.
        let mut logger = QueryLogger::new(query, self.inner.log_settings.clone());
        let sql = logger.sql().as_str();

        // Consume any replies still owed for earlier requests before starting this one.
        self.wait_until_ready().await?;

        // Row metadata handed to each `PgRow`; replaced whenever a `RowDescription` arrives.
        let mut metadata: Arc<PgStatementMetadata>;

        let format = if let Some(mut arguments) = arguments {
            // `Bind` carries the parameter count as a `u16`. Reject oversized argument lists
            // before anything for this query has been written.
            let num_params = u16::try_from(arguments.len()).map_err(|_| {
                err_protocol!(
                    "PgConnection::run(): too many arguments for query: {}",
                    arguments.len()
                )
            })?;

            let (statement, metadata_) = self
                .get_or_prepare(sql, &arguments.types, persistent, metadata_opt, false)
                .await?;

            metadata = metadata_;

            // NOTE(review): presumably fills in argument bytes that depend on the resolved
            // parameter types — confirm against `PgArguments::apply_patches`.
            arguments.apply_patches(self, &metadata.parameters).await?;

            // The calls above may have left replies pending; drain them before `Bind`.
            self.wait_until_ready().await?;

            // Bind the arguments to the statement in the unnamed portal. Parameters and
            // results both use the binary format.
            self.inner.stream.write_msg(Bind {
                portal: PortalId::UNNAMED,
                statement,
                formats: &[PgValueFormat::Binary],
                num_params,
                params: &arguments.buffer,
                result_formats: &[PgValueFormat::Binary],
            })?;

            self.inner.stream.write_msg(message::Execute {
                portal: PortalId::UNNAMED,
                // 0 requests all rows (no row limit).
                limit: 0,
            })?;
            // Close the unnamed portal once it has run, rather than leaving it open until the
            // next `Bind` or the end of the transaction. The `CloseComplete` reply is ignored
            // by the stream below.
            self.inner
                .stream
                .write_msg(Close::Portal(PortalId::UNNAMED))?;

            // `Sync` makes the server process the messages above and finish with `ReadyForQuery`.
            self.write_sync();

            PgValueFormat::Binary
        } else {
            // A simple `Query` is answered with its own `ReadyForQuery`. No `Sync` is written
            // on this path, so account for that reply here.
            self.inner.stream.write_msg(Query(sql))?;
            self.inner.pending_ready_for_query_count += 1;

            // No metadata yet; it arrives with the first `RowDescription`.
            metadata = Arc::new(PgStatementMetadata::default());

            // Simple-query results are decoded as text.
            PgValueFormat::Text
        };

        self.inner.stream.flush().await?;

        Ok(try_stream! {
            loop {
                let message = self.inner.stream.recv().await?;

                match message.format {
                    // Messages this stream does not need to act on.
                    BackendMessageFormat::BindComplete
                    | BackendMessageFormat::ParseComplete
                    | BackendMessageFormat::ParameterDescription
                    | BackendMessageFormat::NoData
                    | BackendMessageFormat::CloseComplete
                    => {
                        // ignored
                    }

                    // A command finished: report how many rows it affected.
                    BackendMessageFormat::CommandComplete => {
                        let cc: CommandComplete = message.decode()?;

                        let rows_affected = cc.rows_affected();
                        logger.increase_rows_affected(rows_affected);
                        r#yield!(Either::Left(PgQueryResult {
                            rows_affected,
                        }));
                    }

                    BackendMessageFormat::EmptyQueryResponse => {
                        // The query string was empty; nothing to yield.
                    }

                    // NOTE(review): `ErrorResponse` is not matched here; presumably `recv()`
                    // already turns it into an `Err` — confirm.

                    BackendMessageFormat::PortalSuspended => {}

                    BackendMessageFormat::RowDescription => {
                        // A new result set starts; the rows that follow use this description.
                        let (columns, column_names) = self
                            .handle_row_description(Some(message.decode()?), false, false)
                            .await?;

                        metadata = Arc::new(PgStatementMetadata {
                            column_names: Arc::new(column_names),
                            columns,
                            parameters: Vec::default(),
                        });
                    }

                    BackendMessageFormat::DataRow => {
                        logger.increment_rows_returned();

                        let data: DataRow = message.decode()?;
                        let row = PgRow {
                            data,
                            format,
                            metadata: Arc::clone(&metadata),
                        };

                        r#yield!(Either::Right(row));
                    }

                    BackendMessageFormat::ReadyForQuery => {
                        // The server is done with this request; end the stream.
                        self.handle_ready_for_query(message)?;
                        break;
                    }

                    _ => {
                        return Err(err_protocol!(
                            "execute: unexpected message: {:?}",
                            message.format
                        ));
                    }
                }
            }

            Ok(())
        })
    }
}
393
394impl<'c> Executor<'c> for &'c mut PgConnection {
395 type Database = Postgres;
396
397 fn fetch_many<'e, 'q, E>(
398 self,
399 mut query: E,
400 ) -> BoxStream<'e, Result<Either<PgQueryResult, PgRow>, Error>>
401 where
402 'c: 'e,
403 E: Execute<'q, Self::Database>,
404 'q: 'e,
405 E: 'q,
406 {
407 #[allow(clippy::map_clone)]
409 let metadata = query.statement().map(|s| Arc::clone(&s.metadata));
410 let arguments = query.take_arguments().map_err(Error::Encode);
411 let persistent = query.persistent();
412 let sql = query.sql();
413
414 Box::pin(try_stream! {
415 let arguments = arguments?;
416 let mut s = pin!(self.run(sql, arguments, persistent, metadata).await?);
417
418 while let Some(v) = s.try_next().await? {
419 r#yield!(v);
420 }
421
422 Ok(())
423 })
424 }
425
426 fn fetch_optional<'e, 'q, E>(self, mut query: E) -> BoxFuture<'e, Result<Option<PgRow>, Error>>
427 where
428 'c: 'e,
429 E: Execute<'q, Self::Database>,
430 'q: 'e,
431 E: 'q,
432 {
433 #[allow(clippy::map_clone)]
435 let metadata = query.statement().map(|s| Arc::clone(&s.metadata));
436 let arguments = query.take_arguments().map_err(Error::Encode);
437 let persistent = query.persistent();
438
439 Box::pin(async move {
440 let sql = query.sql();
441 let arguments = arguments?;
442 let mut s = pin!(self.run(sql, arguments, persistent, metadata).await?);
443
444 let mut ret = None;
449 while let Some(result) = s.try_next().await? {
450 match result {
451 Either::Right(r) if ret.is_none() => ret = Some(r),
452 _ => {}
453 }
454 }
455 Ok(ret)
456 })
457 }
458
459 fn prepare_with<'e>(
460 self,
461 sql: SqlStr,
462 parameters: &'e [PgTypeInfo],
463 ) -> BoxFuture<'e, Result<PgStatement, Error>>
464 where
465 'c: 'e,
466 {
467 Box::pin(async move {
468 self.wait_until_ready().await?;
469
470 let (_, metadata) = self
471 .get_or_prepare(sql.as_str(), parameters, true, None, true)
472 .await?;
473
474 Ok(PgStatement { sql, metadata })
475 })
476 }
477
478 fn describe<'e>(self, sql: SqlStr) -> BoxFuture<'e, Result<Describe<Self::Database>, Error>>
479 where
480 'c: 'e,
481 {
482 Box::pin(async move {
483 self.wait_until_ready().await?;
484
485 let (stmt_id, metadata) = self
486 .get_or_prepare(sql.as_str(), &[], true, None, true)
487 .await?;
488
489 let nullable = self.get_nullable_for_columns(stmt_id, &metadata).await?;
490
491 Ok(Describe {
492 columns: metadata.columns.clone(),
493 nullable,
494 parameters: Some(Either::Left(metadata.parameters.clone())),
495 })
496 })
497 }
498}