sqlx_postgres/connection/executor.rs

1use crate::describe::Describe;
2use crate::error::Error;
3use crate::executor::{Execute, Executor};
4use crate::io::{PortalId, StatementId};
5use crate::logger::QueryLogger;
6use crate::message::{
7    self, BackendMessageFormat, Bind, Close, CommandComplete, DataRow, ParameterDescription, Parse,
8    ParseComplete, Query, RowDescription,
9};
10use crate::statement::PgStatementMetadata;
11use crate::{
12    statement::PgStatement, PgArguments, PgConnection, PgQueryResult, PgRow, PgTypeInfo,
13    PgValueFormat, Postgres,
14};
15use futures_core::future::BoxFuture;
16use futures_core::stream::BoxStream;
17use futures_core::Stream;
18use futures_util::TryStreamExt;
19use sqlx_core::arguments::Arguments;
20use sqlx_core::Either;
21use std::{borrow::Cow, pin::pin, sync::Arc};
22
23async fn prepare(
24    conn: &mut PgConnection,
25    sql: &str,
26    parameters: &[PgTypeInfo],
27    metadata: Option<Arc<PgStatementMetadata>>,
28) -> Result<(StatementId, Arc<PgStatementMetadata>), Error> {
29    let id = conn.inner.next_statement_id;
30    conn.inner.next_statement_id = id.next();
31
32    // build a list of type OIDs to send to the database in the PARSE command
33    // we have not yet started the query sequence, so we are *safe* to cleanly make
34    // additional queries here to get any missing OIDs
35
36    let mut param_types = Vec::with_capacity(parameters.len());
37
38    for ty in parameters {
39        param_types.push(conn.resolve_type_id(&ty.0).await?);
40    }
41
42    // flush and wait until we are re-ready
43    conn.wait_until_ready().await?;
44
45    // next we send the PARSE command to the server
46    conn.inner.stream.write_msg(Parse {
47        param_types: &param_types,
48        query: sql,
49        statement: id,
50    })?;
51
52    if metadata.is_none() {
53        // get the statement columns and parameters
54        conn.inner
55            .stream
56            .write_msg(message::Describe::Statement(id))?;
57    }
58
59    // we ask for the server to immediately send us the result of the PARSE command
60    conn.write_sync();
61    conn.inner.stream.flush().await?;
62
63    // indicates that the SQL query string is now successfully parsed and has semantic validity
64    conn.inner.stream.recv_expect::<ParseComplete>().await?;
65
66    let metadata = if let Some(metadata) = metadata {
67        // each SYNC produces one READY FOR QUERY
68        conn.recv_ready_for_query().await?;
69
70        // we already have metadata
71        metadata
72    } else {
73        let parameters = recv_desc_params(conn).await?;
74
75        let rows = recv_desc_rows(conn).await?;
76
77        // each SYNC produces one READY FOR QUERY
78        conn.recv_ready_for_query().await?;
79
80        let parameters = conn.handle_parameter_description(parameters).await?;
81
82        let (columns, column_names) = conn.handle_row_description(rows, true).await?;
83
84        // ensure that if we did fetch custom data, we wait until we are fully ready before
85        // continuing
86        conn.wait_until_ready().await?;
87
88        Arc::new(PgStatementMetadata {
89            parameters,
90            columns,
91            column_names: Arc::new(column_names),
92        })
93    };
94
95    Ok((id, metadata))
96}
97
98async fn recv_desc_params(conn: &mut PgConnection) -> Result<ParameterDescription, Error> {
99    conn.inner.stream.recv_expect().await
100}
101
102async fn recv_desc_rows(conn: &mut PgConnection) -> Result<Option<RowDescription>, Error> {
103    let rows: Option<RowDescription> = match conn.inner.stream.recv().await? {
104        // describes the rows that will be returned when the statement is eventually executed
105        message if message.format == BackendMessageFormat::RowDescription => {
106            Some(message.decode()?)
107        }
108
109        // no data would be returned if this statement was executed
110        message if message.format == BackendMessageFormat::NoData => None,
111
112        message => {
113            return Err(err_protocol!(
114                "expecting RowDescription or NoData but received {:?}",
115                message.format
116            ));
117        }
118    };
119
120    Ok(rows)
121}
122
123impl PgConnection {
124    // wait for CloseComplete to indicate a statement was closed
125    pub(super) async fn wait_for_close_complete(&mut self, mut count: usize) -> Result<(), Error> {
126        // we need to wait for the [CloseComplete] to be returned from the server
127        while count > 0 {
128            match self.inner.stream.recv().await? {
129                message if message.format == BackendMessageFormat::PortalSuspended => {
130                    // there was an open portal
131                    // this can happen if the last time a statement was used it was not fully executed
132                }
133
134                message if message.format == BackendMessageFormat::CloseComplete => {
135                    // successfully closed the statement (and freed up the server resources)
136                    count -= 1;
137                }
138
139                message => {
140                    return Err(err_protocol!(
141                        "expecting PortalSuspended or CloseComplete but received {:?}",
142                        message.format
143                    ));
144                }
145            }
146        }
147
148        Ok(())
149    }
150
151    #[inline(always)]
152    pub(crate) fn write_sync(&mut self) {
153        self.inner
154            .stream
155            .write_msg(message::Sync)
156            .expect("BUG: Sync should not be too big for protocol");
157
158        // all SYNC messages will return a ReadyForQuery
159        self.inner.pending_ready_for_query_count += 1;
160    }
161
162    async fn get_or_prepare<'a>(
163        &mut self,
164        sql: &str,
165        parameters: &[PgTypeInfo],
166        // should we store the result of this prepare to the cache
167        store_to_cache: bool,
168        // optional metadata that was provided by the user, this means they are reusing
169        // a statement object
170        metadata: Option<Arc<PgStatementMetadata>>,
171    ) -> Result<(StatementId, Arc<PgStatementMetadata>), Error> {
172        if let Some(statement) = self.inner.cache_statement.get_mut(sql) {
173            return Ok((*statement).clone());
174        }
175
176        let statement = prepare(self, sql, parameters, metadata).await?;
177
178        if store_to_cache && self.inner.cache_statement.is_enabled() {
179            if let Some((id, _)) = self.inner.cache_statement.insert(sql, statement.clone()) {
180                self.inner.stream.write_msg(Close::Statement(id))?;
181                self.write_sync();
182
183                self.inner.stream.flush().await?;
184
185                self.wait_for_close_complete(1).await?;
186                self.recv_ready_for_query().await?;
187            }
188        }
189
190        Ok(statement)
191    }
192
193    pub(crate) async fn run<'e, 'c: 'e, 'q: 'e>(
194        &'c mut self,
195        query: &'q str,
196        arguments: Option<PgArguments>,
197        persistent: bool,
198        metadata_opt: Option<Arc<PgStatementMetadata>>,
199    ) -> Result<impl Stream<Item = Result<Either<PgQueryResult, PgRow>, Error>> + 'e, Error> {
200        let mut logger = QueryLogger::new(query, self.inner.log_settings.clone());
201
202        // before we continue, wait until we are "ready" to accept more queries
203        self.wait_until_ready().await?;
204
205        let mut metadata: Arc<PgStatementMetadata>;
206
207        let format = if let Some(mut arguments) = arguments {
208            // Check this before we write anything to the stream.
209            //
210            // Note: Postgres actually interprets this value as unsigned,
211            // making the max number of parameters 65535, not 32767
212            // https://github.com/launchbadge/sqlx/issues/3464
213            // https://www.postgresql.org/docs/current/limits.html
214            let num_params = u16::try_from(arguments.len()).map_err(|_| {
215                err_protocol!(
216                    "PgConnection::run(): too many arguments for query: {}",
217                    arguments.len()
218                )
219            })?;
220
221            // prepare the statement if this our first time executing it
222            // always return the statement ID here
223            let (statement, metadata_) = self
224                .get_or_prepare(query, &arguments.types, persistent, metadata_opt)
225                .await?;
226
227            metadata = metadata_;
228
229            // patch holes created during encoding
230            arguments.apply_patches(self, &metadata.parameters).await?;
231
232            // consume messages till `ReadyForQuery` before bind and execute
233            self.wait_until_ready().await?;
234
235            // bind to attach the arguments to the statement and create a portal
236            self.inner.stream.write_msg(Bind {
237                portal: PortalId::UNNAMED,
238                statement,
239                formats: &[PgValueFormat::Binary],
240                num_params,
241                params: &arguments.buffer,
242                result_formats: &[PgValueFormat::Binary],
243            })?;
244
245            // executes the portal up to the passed limit
246            // the protocol-level limit acts nearly identically to the `LIMIT` in SQL
247            self.inner.stream.write_msg(message::Execute {
248                portal: PortalId::UNNAMED,
249                // Non-zero limits cause query plan pessimization by disabling parallel workers:
250                // https://github.com/launchbadge/sqlx/issues/3673
251                limit: 0,
252            })?;
253            // From https://www.postgresql.org/docs/current/protocol-flow.html:
254            //
255            // "An unnamed portal is destroyed at the end of the transaction, or as
256            // soon as the next Bind statement specifying the unnamed portal as
257            // destination is issued. (Note that a simple Query message also
258            // destroys the unnamed portal."
259
260            // we ask the database server to close the unnamed portal and free the associated resources
261            // earlier - after the execution of the current query.
262            self.inner
263                .stream
264                .write_msg(Close::Portal(PortalId::UNNAMED))?;
265
266            // finally, [Sync] asks postgres to process the messages that we sent and respond with
267            // a [ReadyForQuery] message when it's completely done. Theoretically, we could send
268            // dozens of queries before a [Sync] and postgres can handle that. Execution on the server
269            // is still serial but it would reduce round-trips. Some kind of builder pattern that is
270            // termed batching might suit this.
271            self.write_sync();
272
273            // prepared statements are binary
274            PgValueFormat::Binary
275        } else {
276            // Query will trigger a ReadyForQuery
277            self.inner.stream.write_msg(Query(query))?;
278            self.inner.pending_ready_for_query_count += 1;
279
280            // metadata starts out as "nothing"
281            metadata = Arc::new(PgStatementMetadata::default());
282
283            // and unprepared statements are text
284            PgValueFormat::Text
285        };
286
287        self.inner.stream.flush().await?;
288
289        Ok(try_stream! {
290            loop {
291                let message = self.inner.stream.recv().await?;
292
293                match message.format {
294                    BackendMessageFormat::BindComplete
295                    | BackendMessageFormat::ParseComplete
296                    | BackendMessageFormat::ParameterDescription
297                    | BackendMessageFormat::NoData
298                    // unnamed portal has been closed
299                    | BackendMessageFormat::CloseComplete
300                    => {
301                        // harmless messages to ignore
302                    }
303
304                    // "Execute phase is always terminated by the appearance of
305                    // exactly one of these messages: CommandComplete,
306                    // EmptyQueryResponse (if the portal was created from an
307                    // empty query string), ErrorResponse, or PortalSuspended"
308                    BackendMessageFormat::CommandComplete => {
309                        // a SQL command completed normally
310                        let cc: CommandComplete = message.decode()?;
311
312                        let rows_affected = cc.rows_affected();
313                        logger.increase_rows_affected(rows_affected);
314                        r#yield!(Either::Left(PgQueryResult {
315                            rows_affected,
316                        }));
317                    }
318
319                    BackendMessageFormat::EmptyQueryResponse => {
320                        // empty query string passed to an unprepared execute
321                    }
322
323                    // Message::ErrorResponse is handled in self.stream.recv()
324
325                    // incomplete query execution has finished
326                    BackendMessageFormat::PortalSuspended => {}
327
328                    BackendMessageFormat::RowDescription => {
329                        // indicates that a *new* set of rows are about to be returned
330                        let (columns, column_names) = self
331                            .handle_row_description(Some(message.decode()?), false)
332                            .await?;
333
334                        metadata = Arc::new(PgStatementMetadata {
335                            column_names: Arc::new(column_names),
336                            columns,
337                            parameters: Vec::default(),
338                        });
339                    }
340
341                    BackendMessageFormat::DataRow => {
342                        logger.increment_rows_returned();
343
344                        // one of the set of rows returned by a SELECT, FETCH, etc query
345                        let data: DataRow = message.decode()?;
346                        let row = PgRow {
347                            data,
348                            format,
349                            metadata: Arc::clone(&metadata),
350                        };
351
352                        r#yield!(Either::Right(row));
353                    }
354
355                    BackendMessageFormat::ReadyForQuery => {
356                        // processing of the query string is complete
357                        self.handle_ready_for_query(message)?;
358                        break;
359                    }
360
361                    _ => {
362                        return Err(err_protocol!(
363                            "execute: unexpected message: {:?}",
364                            message.format
365                        ));
366                    }
367                }
368            }
369
370            Ok(())
371        })
372    }
373}
374
375impl<'c> Executor<'c> for &'c mut PgConnection {
376    type Database = Postgres;
377
378    fn fetch_many<'e, 'q, E>(
379        self,
380        mut query: E,
381    ) -> BoxStream<'e, Result<Either<PgQueryResult, PgRow>, Error>>
382    where
383        'c: 'e,
384        E: Execute<'q, Self::Database>,
385        'q: 'e,
386        E: 'q,
387    {
388        let sql = query.sql();
389        // False positive: https://github.com/rust-lang/rust-clippy/issues/12560
390        #[allow(clippy::map_clone)]
391        let metadata = query.statement().map(|s| Arc::clone(&s.metadata));
392        let arguments = query.take_arguments().map_err(Error::Encode);
393        let persistent = query.persistent();
394
395        Box::pin(try_stream! {
396            let arguments = arguments?;
397            let mut s = pin!(self.run(sql, arguments, persistent, metadata).await?);
398
399            while let Some(v) = s.try_next().await? {
400                r#yield!(v);
401            }
402
403            Ok(())
404        })
405    }
406
407    fn fetch_optional<'e, 'q, E>(self, mut query: E) -> BoxFuture<'e, Result<Option<PgRow>, Error>>
408    where
409        'c: 'e,
410        E: Execute<'q, Self::Database>,
411        'q: 'e,
412        E: 'q,
413    {
414        let sql = query.sql();
415        // False positive: https://github.com/rust-lang/rust-clippy/issues/12560
416        #[allow(clippy::map_clone)]
417        let metadata = query.statement().map(|s| Arc::clone(&s.metadata));
418        let arguments = query.take_arguments().map_err(Error::Encode);
419        let persistent = query.persistent();
420
421        Box::pin(async move {
422            let arguments = arguments?;
423            let mut s = pin!(self.run(sql, arguments, persistent, metadata).await?);
424
425            // With deferred constraints we need to check all responses as we
426            // could get a OK response (with uncommitted data), only to get an
427            // error response after (when the deferred constraint is actually
428            // checked).
429            let mut ret = None;
430            while let Some(result) = s.try_next().await? {
431                match result {
432                    Either::Right(r) if ret.is_none() => ret = Some(r),
433                    _ => {}
434                }
435            }
436            Ok(ret)
437        })
438    }
439
440    fn prepare_with<'e, 'q: 'e>(
441        self,
442        sql: &'q str,
443        parameters: &'e [PgTypeInfo],
444    ) -> BoxFuture<'e, Result<PgStatement<'q>, Error>>
445    where
446        'c: 'e,
447    {
448        Box::pin(async move {
449            self.wait_until_ready().await?;
450
451            let (_, metadata) = self.get_or_prepare(sql, parameters, true, None).await?;
452
453            Ok(PgStatement {
454                sql: Cow::Borrowed(sql),
455                metadata,
456            })
457        })
458    }
459
460    fn describe<'e, 'q: 'e>(
461        self,
462        sql: &'q str,
463    ) -> BoxFuture<'e, Result<Describe<Self::Database>, Error>>
464    where
465        'c: 'e,
466    {
467        Box::pin(async move {
468            self.wait_until_ready().await?;
469
470            let (stmt_id, metadata) = self.get_or_prepare(sql, &[], true, None).await?;
471
472            let nullable = self.get_nullable_for_columns(stmt_id, &metadata).await?;
473
474            Ok(Describe {
475                columns: metadata.columns.clone(),
476                nullable,
477                parameters: Some(Either::Left(metadata.parameters.clone())),
478            })
479        })
480    }
481}