sqlx_postgres/connection/
executor.rs1use crate::describe::Describe;
2use crate::error::Error;
3use crate::executor::{Execute, Executor};
4use crate::io::{PortalId, StatementId};
5use crate::logger::QueryLogger;
6use crate::message::{
7 self, BackendMessageFormat, Bind, Close, CommandComplete, DataRow, ParameterDescription, Parse,
8 ParseComplete, Query, RowDescription,
9};
10use crate::statement::PgStatementMetadata;
11use crate::{
12 statement::PgStatement, PgArguments, PgConnection, PgQueryResult, PgRow, PgTypeInfo,
13 PgValueFormat, Postgres,
14};
15use futures_core::future::BoxFuture;
16use futures_core::stream::BoxStream;
17use futures_core::Stream;
18use futures_util::TryStreamExt;
19use sqlx_core::arguments::Arguments;
20use sqlx_core::Either;
21use std::{borrow::Cow, pin::pin, sync::Arc};
22
23async fn prepare(
24 conn: &mut PgConnection,
25 sql: &str,
26 parameters: &[PgTypeInfo],
27 metadata: Option<Arc<PgStatementMetadata>>,
28) -> Result<(StatementId, Arc<PgStatementMetadata>), Error> {
29 let id = conn.inner.next_statement_id;
30 conn.inner.next_statement_id = id.next();
31
32 let mut param_types = Vec::with_capacity(parameters.len());
37
38 for ty in parameters {
39 param_types.push(conn.resolve_type_id(&ty.0).await?);
40 }
41
42 conn.wait_until_ready().await?;
44
45 conn.inner.stream.write_msg(Parse {
47 param_types: ¶m_types,
48 query: sql,
49 statement: id,
50 })?;
51
52 if metadata.is_none() {
53 conn.inner
55 .stream
56 .write_msg(message::Describe::Statement(id))?;
57 }
58
59 conn.write_sync();
61 conn.inner.stream.flush().await?;
62
63 conn.inner.stream.recv_expect::<ParseComplete>().await?;
65
66 let metadata = if let Some(metadata) = metadata {
67 conn.recv_ready_for_query().await?;
69
70 metadata
72 } else {
73 let parameters = recv_desc_params(conn).await?;
74
75 let rows = recv_desc_rows(conn).await?;
76
77 conn.recv_ready_for_query().await?;
79
80 let parameters = conn.handle_parameter_description(parameters).await?;
81
82 let (columns, column_names) = conn.handle_row_description(rows, true).await?;
83
84 conn.wait_until_ready().await?;
87
88 Arc::new(PgStatementMetadata {
89 parameters,
90 columns,
91 column_names: Arc::new(column_names),
92 })
93 };
94
95 Ok((id, metadata))
96}
97
98async fn recv_desc_params(conn: &mut PgConnection) -> Result<ParameterDescription, Error> {
99 conn.inner.stream.recv_expect().await
100}
101
102async fn recv_desc_rows(conn: &mut PgConnection) -> Result<Option<RowDescription>, Error> {
103 let rows: Option<RowDescription> = match conn.inner.stream.recv().await? {
104 message if message.format == BackendMessageFormat::RowDescription => {
106 Some(message.decode()?)
107 }
108
109 message if message.format == BackendMessageFormat::NoData => None,
111
112 message => {
113 return Err(err_protocol!(
114 "expecting RowDescription or NoData but received {:?}",
115 message.format
116 ));
117 }
118 };
119
120 Ok(rows)
121}
122
123impl PgConnection {
124 pub(super) async fn wait_for_close_complete(&mut self, mut count: usize) -> Result<(), Error> {
126 while count > 0 {
128 match self.inner.stream.recv().await? {
129 message if message.format == BackendMessageFormat::PortalSuspended => {
130 }
133
134 message if message.format == BackendMessageFormat::CloseComplete => {
135 count -= 1;
137 }
138
139 message => {
140 return Err(err_protocol!(
141 "expecting PortalSuspended or CloseComplete but received {:?}",
142 message.format
143 ));
144 }
145 }
146 }
147
148 Ok(())
149 }
150
151 #[inline(always)]
152 pub(crate) fn write_sync(&mut self) {
153 self.inner
154 .stream
155 .write_msg(message::Sync)
156 .expect("BUG: Sync should not be too big for protocol");
157
158 self.inner.pending_ready_for_query_count += 1;
160 }
161
162 async fn get_or_prepare<'a>(
163 &mut self,
164 sql: &str,
165 parameters: &[PgTypeInfo],
166 store_to_cache: bool,
168 metadata: Option<Arc<PgStatementMetadata>>,
171 ) -> Result<(StatementId, Arc<PgStatementMetadata>), Error> {
172 if let Some(statement) = self.inner.cache_statement.get_mut(sql) {
173 return Ok((*statement).clone());
174 }
175
176 let statement = prepare(self, sql, parameters, metadata).await?;
177
178 if store_to_cache && self.inner.cache_statement.is_enabled() {
179 if let Some((id, _)) = self.inner.cache_statement.insert(sql, statement.clone()) {
180 self.inner.stream.write_msg(Close::Statement(id))?;
181 self.write_sync();
182
183 self.inner.stream.flush().await?;
184
185 self.wait_for_close_complete(1).await?;
186 self.recv_ready_for_query().await?;
187 }
188 }
189
190 Ok(statement)
191 }
192
193 pub(crate) async fn run<'e, 'c: 'e, 'q: 'e>(
194 &'c mut self,
195 query: &'q str,
196 arguments: Option<PgArguments>,
197 persistent: bool,
198 metadata_opt: Option<Arc<PgStatementMetadata>>,
199 ) -> Result<impl Stream<Item = Result<Either<PgQueryResult, PgRow>, Error>> + 'e, Error> {
200 let mut logger = QueryLogger::new(query, self.inner.log_settings.clone());
201
202 self.wait_until_ready().await?;
204
205 let mut metadata: Arc<PgStatementMetadata>;
206
207 let format = if let Some(mut arguments) = arguments {
208 let num_params = u16::try_from(arguments.len()).map_err(|_| {
215 err_protocol!(
216 "PgConnection::run(): too many arguments for query: {}",
217 arguments.len()
218 )
219 })?;
220
221 let (statement, metadata_) = self
224 .get_or_prepare(query, &arguments.types, persistent, metadata_opt)
225 .await?;
226
227 metadata = metadata_;
228
229 arguments.apply_patches(self, &metadata.parameters).await?;
231
232 self.wait_until_ready().await?;
234
235 self.inner.stream.write_msg(Bind {
237 portal: PortalId::UNNAMED,
238 statement,
239 formats: &[PgValueFormat::Binary],
240 num_params,
241 params: &arguments.buffer,
242 result_formats: &[PgValueFormat::Binary],
243 })?;
244
245 self.inner.stream.write_msg(message::Execute {
248 portal: PortalId::UNNAMED,
249 limit: 0,
252 })?;
253 self.inner
263 .stream
264 .write_msg(Close::Portal(PortalId::UNNAMED))?;
265
266 self.write_sync();
272
273 PgValueFormat::Binary
275 } else {
276 self.inner.stream.write_msg(Query(query))?;
278 self.inner.pending_ready_for_query_count += 1;
279
280 metadata = Arc::new(PgStatementMetadata::default());
282
283 PgValueFormat::Text
285 };
286
287 self.inner.stream.flush().await?;
288
289 Ok(try_stream! {
290 loop {
291 let message = self.inner.stream.recv().await?;
292
293 match message.format {
294 BackendMessageFormat::BindComplete
295 | BackendMessageFormat::ParseComplete
296 | BackendMessageFormat::ParameterDescription
297 | BackendMessageFormat::NoData
298 | BackendMessageFormat::CloseComplete
300 => {
301 }
303
304 BackendMessageFormat::CommandComplete => {
309 let cc: CommandComplete = message.decode()?;
311
312 let rows_affected = cc.rows_affected();
313 logger.increase_rows_affected(rows_affected);
314 r#yield!(Either::Left(PgQueryResult {
315 rows_affected,
316 }));
317 }
318
319 BackendMessageFormat::EmptyQueryResponse => {
320 }
322
323 BackendMessageFormat::PortalSuspended => {}
327
328 BackendMessageFormat::RowDescription => {
329 let (columns, column_names) = self
331 .handle_row_description(Some(message.decode()?), false)
332 .await?;
333
334 metadata = Arc::new(PgStatementMetadata {
335 column_names: Arc::new(column_names),
336 columns,
337 parameters: Vec::default(),
338 });
339 }
340
341 BackendMessageFormat::DataRow => {
342 logger.increment_rows_returned();
343
344 let data: DataRow = message.decode()?;
346 let row = PgRow {
347 data,
348 format,
349 metadata: Arc::clone(&metadata),
350 };
351
352 r#yield!(Either::Right(row));
353 }
354
355 BackendMessageFormat::ReadyForQuery => {
356 self.handle_ready_for_query(message)?;
358 break;
359 }
360
361 _ => {
362 return Err(err_protocol!(
363 "execute: unexpected message: {:?}",
364 message.format
365 ));
366 }
367 }
368 }
369
370 Ok(())
371 })
372 }
373}
374
375impl<'c> Executor<'c> for &'c mut PgConnection {
376 type Database = Postgres;
377
378 fn fetch_many<'e, 'q, E>(
379 self,
380 mut query: E,
381 ) -> BoxStream<'e, Result<Either<PgQueryResult, PgRow>, Error>>
382 where
383 'c: 'e,
384 E: Execute<'q, Self::Database>,
385 'q: 'e,
386 E: 'q,
387 {
388 let sql = query.sql();
389 #[allow(clippy::map_clone)]
391 let metadata = query.statement().map(|s| Arc::clone(&s.metadata));
392 let arguments = query.take_arguments().map_err(Error::Encode);
393 let persistent = query.persistent();
394
395 Box::pin(try_stream! {
396 let arguments = arguments?;
397 let mut s = pin!(self.run(sql, arguments, persistent, metadata).await?);
398
399 while let Some(v) = s.try_next().await? {
400 r#yield!(v);
401 }
402
403 Ok(())
404 })
405 }
406
407 fn fetch_optional<'e, 'q, E>(self, mut query: E) -> BoxFuture<'e, Result<Option<PgRow>, Error>>
408 where
409 'c: 'e,
410 E: Execute<'q, Self::Database>,
411 'q: 'e,
412 E: 'q,
413 {
414 let sql = query.sql();
415 #[allow(clippy::map_clone)]
417 let metadata = query.statement().map(|s| Arc::clone(&s.metadata));
418 let arguments = query.take_arguments().map_err(Error::Encode);
419 let persistent = query.persistent();
420
421 Box::pin(async move {
422 let arguments = arguments?;
423 let mut s = pin!(self.run(sql, arguments, persistent, metadata).await?);
424
425 let mut ret = None;
430 while let Some(result) = s.try_next().await? {
431 match result {
432 Either::Right(r) if ret.is_none() => ret = Some(r),
433 _ => {}
434 }
435 }
436 Ok(ret)
437 })
438 }
439
440 fn prepare_with<'e, 'q: 'e>(
441 self,
442 sql: &'q str,
443 parameters: &'e [PgTypeInfo],
444 ) -> BoxFuture<'e, Result<PgStatement<'q>, Error>>
445 where
446 'c: 'e,
447 {
448 Box::pin(async move {
449 self.wait_until_ready().await?;
450
451 let (_, metadata) = self.get_or_prepare(sql, parameters, true, None).await?;
452
453 Ok(PgStatement {
454 sql: Cow::Borrowed(sql),
455 metadata,
456 })
457 })
458 }
459
460 fn describe<'e, 'q: 'e>(
461 self,
462 sql: &'q str,
463 ) -> BoxFuture<'e, Result<Describe<Self::Database>, Error>>
464 where
465 'c: 'e,
466 {
467 Box::pin(async move {
468 self.wait_until_ready().await?;
469
470 let (stmt_id, metadata) = self.get_or_prepare(sql, &[], true, None).await?;
471
472 let nullable = self.get_nullable_for_columns(stmt_id, &metadata).await?;
473
474 Ok(Describe {
475 columns: metadata.columns.clone(),
476 nullable,
477 parameters: Some(Either::Left(metadata.parameters.clone())),
478 })
479 })
480 }
481}