sqlx_postgres/connection/
executor.rs1use crate::error::Error;
2use crate::executor::{Execute, Executor};
3use crate::io::{PortalId, StatementId};
4use crate::logger::QueryLogger;
5use crate::message::{
6 self, BackendMessageFormat, Bind, Close, CommandComplete, DataRow, ParameterDescription, Parse,
7 ParseComplete, RowDescription,
8};
9use crate::statement::PgStatementMetadata;
10use crate::{
11 statement::PgStatement, PgArguments, PgConnection, PgQueryResult, PgRow, PgTypeInfo,
12 PgValueFormat, Postgres,
13};
14use futures_core::future::BoxFuture;
15use futures_core::stream::BoxStream;
16use futures_core::Stream;
17use futures_util::TryStreamExt;
18use sqlx_core::arguments::Arguments;
19use sqlx_core::sql_str::SqlStr;
20use sqlx_core::Either;
21use std::{pin::pin, sync::Arc};
22
23async fn prepare(
24 conn: &mut PgConnection,
25 sql: &str,
26 arg_types: &[PgTypeInfo],
27 metadata: Option<Arc<PgStatementMetadata>>,
28 persistent: bool,
29 resolve_column_origin: bool,
30) -> Result<(StatementId, Arc<PgStatementMetadata>), Error> {
31 let id = if persistent {
32 let id = conn.inner.next_statement_id;
33 conn.inner.next_statement_id = id.next();
34 id
35 } else {
36 StatementId::UNNAMED
37 };
38
39 let param_types = conn.resolve_types(arg_types).await?;
43
44 conn.wait_until_ready().await?;
46
47 conn.inner.stream.write_msg(Parse {
49 param_types: ¶m_types,
50 query: sql,
51 statement: id,
52 })?;
53
54 if metadata.is_none() {
55 conn.inner
57 .stream
58 .write_msg(message::Describe::Statement(id))?;
59 }
60
61 conn.write_sync();
63 conn.inner.stream.flush().await?;
64
65 conn.inner.stream.recv_expect::<ParseComplete>().await?;
67
68 let metadata = if let Some(metadata) = metadata {
69 conn.recv_ready_for_query().await?;
71
72 metadata
74 } else {
75 let parameters = recv_desc_params(conn).await?;
76
77 let row_desc = recv_desc_rows(conn).await?;
78
79 conn.recv_ready_for_query().await?;
81
82 let metadata = conn
83 .resolve_statement_metadata::<true>(Some(parameters), row_desc, resolve_column_origin)
84 .await?;
85
86 conn.wait_until_ready().await?;
89
90 metadata
91 };
92
93 Ok((id, metadata))
94}
95
96async fn recv_desc_params(conn: &mut PgConnection) -> Result<ParameterDescription, Error> {
97 conn.inner.stream.recv_expect().await
98}
99
100async fn recv_desc_rows(conn: &mut PgConnection) -> Result<Option<RowDescription>, Error> {
101 let rows: Option<RowDescription> = match conn.inner.stream.recv().await? {
102 message if message.format == BackendMessageFormat::RowDescription => {
104 Some(message.decode()?)
105 }
106
107 message if message.format == BackendMessageFormat::NoData => None,
109
110 message => {
111 return Err(err_protocol!(
112 "expecting RowDescription or NoData but received {:?}",
113 message.format
114 ));
115 }
116 };
117
118 Ok(rows)
119}
120
121impl PgConnection {
122 pub(super) async fn wait_for_close_complete(&mut self, mut count: usize) -> Result<(), Error> {
124 while count > 0 {
126 match self.inner.stream.recv().await? {
127 message if message.format == BackendMessageFormat::PortalSuspended => {
128 }
131
132 message if message.format == BackendMessageFormat::CloseComplete => {
133 count -= 1;
135 }
136
137 message => {
138 return Err(err_protocol!(
139 "expecting PortalSuspended or CloseComplete but received {:?}",
140 message.format
141 ));
142 }
143 }
144 }
145
146 Ok(())
147 }
148
149 #[inline(always)]
150 pub(crate) fn write_sync(&mut self) {
151 self.inner
152 .stream
153 .write_msg(message::Sync)
154 .expect("BUG: Sync should not be too big for protocol");
155
156 self.inner.pending_ready_for_query_count += 1;
158 }
159
160 async fn get_or_prepare(
161 &mut self,
162 sql: &str,
163 parameters: &[PgTypeInfo],
164 persistent: bool,
165 metadata: Option<Arc<PgStatementMetadata>>,
168 resolve_column_origin: bool,
169 ) -> Result<(StatementId, Arc<PgStatementMetadata>), Error> {
170 if let Some(statement) = self.inner.cache_statement.get_mut(sql) {
171 return Ok((*statement).clone());
172 }
173
174 let statement = prepare(
175 self,
176 sql,
177 parameters,
178 metadata,
179 persistent,
180 resolve_column_origin,
181 )
182 .await?;
183
184 if persistent && self.inner.cache_statement.is_enabled() {
185 if let Some((id, _)) = self.inner.cache_statement.insert(sql, statement.clone()) {
186 self.inner.stream.write_msg(Close::Statement(id))?;
187 self.write_sync();
188
189 self.inner.stream.flush().await?;
190
191 self.wait_for_close_complete(1).await?;
192 self.recv_ready_for_query().await?;
193 }
194 }
195
196 Ok(statement)
197 }
198
199 pub(crate) async fn run<'e, 'c: 'e, 'q: 'e>(
200 &'c mut self,
201 query: SqlStr,
202 arguments: Option<PgArguments>,
203 persistent: bool,
204 metadata_opt: Option<Arc<PgStatementMetadata>>,
205 ) -> Result<impl Stream<Item = Result<Either<PgQueryResult, PgRow>, Error>> + 'e, Error> {
206 let mut logger = QueryLogger::new(query, self.inner.log_settings.clone());
207 let sql = logger.sql().as_str();
208
209 self.wait_until_ready().await?;
211
212 let mut metadata: Arc<PgStatementMetadata>;
213
214 let format = if let Some(mut arguments) = arguments {
215 let num_params = u16::try_from(arguments.len()).map_err(|_| {
222 err_protocol!(
223 "PgConnection::run(): too many arguments for query: {}",
224 arguments.len()
225 )
226 })?;
227
228 let (statement, metadata_) = self
231 .get_or_prepare(sql, &arguments.types, persistent, metadata_opt, false)
232 .await?;
233
234 metadata = metadata_;
235
236 arguments.apply_patches(self, &metadata.parameters).await?;
238
239 self.wait_until_ready().await?;
241
242 self.inner.stream.write_msg(Bind {
244 portal: PortalId::UNNAMED,
245 statement,
246 formats: &[PgValueFormat::Binary],
247 num_params,
248 params: &arguments.buffer,
249 result_formats: &[PgValueFormat::Binary],
250 })?;
251
252 self.inner.stream.write_msg(message::Execute {
255 portal: PortalId::UNNAMED,
256 limit: 0,
259 })?;
260 self.inner
270 .stream
271 .write_msg(Close::Portal(PortalId::UNNAMED))?;
272
273 self.write_sync();
279
280 PgValueFormat::Binary
282 } else {
283 self.queue_simple_query(sql)?;
285
286 metadata = Arc::new(PgStatementMetadata::default());
288
289 PgValueFormat::Text
291 };
292
293 self.inner.stream.flush().await?;
294
295 Ok(try_stream! {
296 loop {
297 let message = self.inner.stream.recv().await?;
298
299 match message.format {
300 BackendMessageFormat::BindComplete
301 | BackendMessageFormat::ParseComplete
302 | BackendMessageFormat::ParameterDescription
303 | BackendMessageFormat::NoData
304 | BackendMessageFormat::CloseComplete
306 => {
307 }
309
310 BackendMessageFormat::CommandComplete => {
315 let cc: CommandComplete = message.decode()?;
317
318 let rows_affected = cc.rows_affected();
319 logger.increase_rows_affected(rows_affected);
320 r#yield!(Either::Left(PgQueryResult {
321 rows_affected,
322 }));
323 }
324
325 BackendMessageFormat::EmptyQueryResponse => {
326 }
328
329 BackendMessageFormat::PortalSuspended => {}
333
334 BackendMessageFormat::RowDescription => {
336 let new_metadata = self.resolve_statement_metadata::<false>(
337 None,
338 Some(message.decode()?),
339 false,
340 ).await?;
341
342 metadata = new_metadata;
343 }
344
345 BackendMessageFormat::DataRow => {
346 logger.increment_rows_returned();
347
348 let data: DataRow = message.decode()?;
350 let row = PgRow {
351 data,
352 format,
353 metadata: Arc::clone(&metadata),
354 };
355
356 r#yield!(Either::Right(row));
357 }
358
359 BackendMessageFormat::ReadyForQuery => {
360 self.handle_ready_for_query(message)?;
362 break;
363 }
364
365 _ => {
366 return Err(err_protocol!(
367 "execute: unexpected message: {:?}",
368 message.format
369 ));
370 }
371 }
372 }
373
374 Ok(())
375 })
376 }
377}
378
379impl<'c> Executor<'c> for &'c mut PgConnection {
380 type Database = Postgres;
381
382 fn fetch_many<'e, 'q, E>(
383 self,
384 mut query: E,
385 ) -> BoxStream<'e, Result<Either<PgQueryResult, PgRow>, Error>>
386 where
387 'c: 'e,
388 E: Execute<'q, Self::Database>,
389 'q: 'e,
390 E: 'q,
391 {
392 #[allow(clippy::map_clone)]
394 let metadata = query.statement().map(|s| Arc::clone(&s.metadata));
395 let arguments = query.take_arguments().map_err(Error::Encode);
396 let persistent = query.persistent();
397 let sql = query.sql();
398
399 Box::pin(try_stream! {
400 let arguments = arguments?;
401 let mut s = pin!(self.run(sql, arguments, persistent, metadata).await?);
402
403 while let Some(v) = s.try_next().await? {
404 r#yield!(v);
405 }
406
407 Ok(())
408 })
409 }
410
411 fn fetch_optional<'e, 'q, E>(self, mut query: E) -> BoxFuture<'e, Result<Option<PgRow>, Error>>
412 where
413 'c: 'e,
414 E: Execute<'q, Self::Database>,
415 'q: 'e,
416 E: 'q,
417 {
418 #[allow(clippy::map_clone)]
420 let metadata = query.statement().map(|s| Arc::clone(&s.metadata));
421 let arguments = query.take_arguments().map_err(Error::Encode);
422 let persistent = query.persistent();
423
424 Box::pin(async move {
425 let sql = query.sql();
426 let arguments = arguments?;
427 let mut s = pin!(self.run(sql, arguments, persistent, metadata).await?);
428
429 let mut ret = None;
434 while let Some(result) = s.try_next().await? {
435 match result {
436 Either::Right(r) if ret.is_none() => ret = Some(r),
437 _ => {}
438 }
439 }
440 Ok(ret)
441 })
442 }
443
444 fn prepare_with<'e>(
445 self,
446 sql: SqlStr,
447 parameters: &'e [PgTypeInfo],
448 ) -> BoxFuture<'e, Result<PgStatement, Error>>
449 where
450 'c: 'e,
451 {
452 Box::pin(async move {
453 self.wait_until_ready().await?;
454
455 let (_, metadata) = self
456 .get_or_prepare(sql.as_str(), parameters, true, None, true)
457 .await?;
458
459 Ok(PgStatement { sql, metadata })
460 })
461 }
462
463 #[cfg(feature = "offline")]
464 fn describe<'e>(
465 self,
466 sql: SqlStr,
467 ) -> BoxFuture<'e, Result<crate::describe::Describe<Self::Database>, Error>>
468 where
469 'c: 'e,
470 {
471 Box::pin(async move {
472 self.wait_until_ready().await?;
473
474 let (stmt_id, metadata) = self
475 .get_or_prepare(sql.as_str(), &[], true, None, true)
476 .await?;
477
478 let nullable = self.get_nullable_for_columns(stmt_id, &metadata).await?;
479
480 Ok(crate::describe::Describe {
481 columns: metadata.columns.clone(),
482 nullable,
483 parameters: Some(Either::Left(metadata.parameters.clone())),
484 })
485 })
486 }
487}