// sqlx_postgres/connection/executor.rs

1use crate::describe::Describe;
2use crate::error::Error;
3use crate::executor::{Execute, Executor};
4use crate::io::{PortalId, StatementId};
5use crate::logger::QueryLogger;
6use crate::message::{
7    self, BackendMessageFormat, Bind, Close, CommandComplete, DataRow, ParameterDescription, Parse,
8    ParseComplete, Query, RowDescription,
9};
10use crate::statement::PgStatementMetadata;
11use crate::{
12    statement::PgStatement, PgArguments, PgConnection, PgQueryResult, PgRow, PgTypeInfo,
13    PgValueFormat, Postgres,
14};
15use futures_core::future::BoxFuture;
16use futures_core::stream::BoxStream;
17use futures_core::Stream;
18use futures_util::TryStreamExt;
19use sqlx_core::arguments::Arguments;
20use sqlx_core::Either;
21use std::{borrow::Cow, pin::pin, sync::Arc};
22
23async fn prepare(
24    conn: &mut PgConnection,
25    sql: &str,
26    parameters: &[PgTypeInfo],
27    metadata: Option<Arc<PgStatementMetadata>>,
28    persistent: bool,
29) -> Result<(StatementId, Arc<PgStatementMetadata>), Error> {
30    let id = if persistent {
31        let id = conn.inner.next_statement_id;
32        conn.inner.next_statement_id = id.next();
33        id
34    } else {
35        StatementId::UNNAMED
36    };
37
38    // build a list of type OIDs to send to the database in the PARSE command
39    // we have not yet started the query sequence, so we are *safe* to cleanly make
40    // additional queries here to get any missing OIDs
41
42    let mut param_types = Vec::with_capacity(parameters.len());
43
44    for ty in parameters {
45        param_types.push(conn.resolve_type_id(&ty.0).await?);
46    }
47
48    // flush and wait until we are re-ready
49    conn.wait_until_ready().await?;
50
51    // next we send the PARSE command to the server
52    conn.inner.stream.write_msg(Parse {
53        param_types: &param_types,
54        query: sql,
55        statement: id,
56    })?;
57
58    if metadata.is_none() {
59        // get the statement columns and parameters
60        conn.inner
61            .stream
62            .write_msg(message::Describe::Statement(id))?;
63    }
64
65    // we ask for the server to immediately send us the result of the PARSE command
66    conn.write_sync();
67    conn.inner.stream.flush().await?;
68
69    // indicates that the SQL query string is now successfully parsed and has semantic validity
70    conn.inner.stream.recv_expect::<ParseComplete>().await?;
71
72    let metadata = if let Some(metadata) = metadata {
73        // each SYNC produces one READY FOR QUERY
74        conn.recv_ready_for_query().await?;
75
76        // we already have metadata
77        metadata
78    } else {
79        let parameters = recv_desc_params(conn).await?;
80
81        let rows = recv_desc_rows(conn).await?;
82
83        // each SYNC produces one READY FOR QUERY
84        conn.recv_ready_for_query().await?;
85
86        let parameters = conn.handle_parameter_description(parameters).await?;
87
88        let (columns, column_names) = conn.handle_row_description(rows, true).await?;
89
90        // ensure that if we did fetch custom data, we wait until we are fully ready before
91        // continuing
92        conn.wait_until_ready().await?;
93
94        Arc::new(PgStatementMetadata {
95            parameters,
96            columns,
97            column_names: Arc::new(column_names),
98        })
99    };
100
101    Ok((id, metadata))
102}
103
104async fn recv_desc_params(conn: &mut PgConnection) -> Result<ParameterDescription, Error> {
105    conn.inner.stream.recv_expect().await
106}
107
108async fn recv_desc_rows(conn: &mut PgConnection) -> Result<Option<RowDescription>, Error> {
109    let rows: Option<RowDescription> = match conn.inner.stream.recv().await? {
110        // describes the rows that will be returned when the statement is eventually executed
111        message if message.format == BackendMessageFormat::RowDescription => {
112            Some(message.decode()?)
113        }
114
115        // no data would be returned if this statement was executed
116        message if message.format == BackendMessageFormat::NoData => None,
117
118        message => {
119            return Err(err_protocol!(
120                "expecting RowDescription or NoData but received {:?}",
121                message.format
122            ));
123        }
124    };
125
126    Ok(rows)
127}
128
129impl PgConnection {
130    // wait for CloseComplete to indicate a statement was closed
131    pub(super) async fn wait_for_close_complete(&mut self, mut count: usize) -> Result<(), Error> {
132        // we need to wait for the [CloseComplete] to be returned from the server
133        while count > 0 {
134            match self.inner.stream.recv().await? {
135                message if message.format == BackendMessageFormat::PortalSuspended => {
136                    // there was an open portal
137                    // this can happen if the last time a statement was used it was not fully executed
138                }
139
140                message if message.format == BackendMessageFormat::CloseComplete => {
141                    // successfully closed the statement (and freed up the server resources)
142                    count -= 1;
143                }
144
145                message => {
146                    return Err(err_protocol!(
147                        "expecting PortalSuspended or CloseComplete but received {:?}",
148                        message.format
149                    ));
150                }
151            }
152        }
153
154        Ok(())
155    }
156
157    #[inline(always)]
158    pub(crate) fn write_sync(&mut self) {
159        self.inner
160            .stream
161            .write_msg(message::Sync)
162            .expect("BUG: Sync should not be too big for protocol");
163
164        // all SYNC messages will return a ReadyForQuery
165        self.inner.pending_ready_for_query_count += 1;
166    }
167
168    async fn get_or_prepare<'a>(
169        &mut self,
170        sql: &str,
171        parameters: &[PgTypeInfo],
172        persistent: bool,
173        // optional metadata that was provided by the user, this means they are reusing
174        // a statement object
175        metadata: Option<Arc<PgStatementMetadata>>,
176    ) -> Result<(StatementId, Arc<PgStatementMetadata>), Error> {
177        if let Some(statement) = self.inner.cache_statement.get_mut(sql) {
178            return Ok((*statement).clone());
179        }
180
181        let statement = prepare(self, sql, parameters, metadata, persistent).await?;
182
183        if persistent && self.inner.cache_statement.is_enabled() {
184            if let Some((id, _)) = self.inner.cache_statement.insert(sql, statement.clone()) {
185                self.inner.stream.write_msg(Close::Statement(id))?;
186                self.write_sync();
187
188                self.inner.stream.flush().await?;
189
190                self.wait_for_close_complete(1).await?;
191                self.recv_ready_for_query().await?;
192            }
193        }
194
195        Ok(statement)
196    }
197
198    pub(crate) async fn run<'e, 'c: 'e, 'q: 'e>(
199        &'c mut self,
200        query: &'q str,
201        arguments: Option<PgArguments>,
202        persistent: bool,
203        metadata_opt: Option<Arc<PgStatementMetadata>>,
204    ) -> Result<impl Stream<Item = Result<Either<PgQueryResult, PgRow>, Error>> + 'e, Error> {
205        let mut logger = QueryLogger::new(query, self.inner.log_settings.clone());
206
207        // before we continue, wait until we are "ready" to accept more queries
208        self.wait_until_ready().await?;
209
210        let mut metadata: Arc<PgStatementMetadata>;
211
212        let format = if let Some(mut arguments) = arguments {
213            // Check this before we write anything to the stream.
214            //
215            // Note: Postgres actually interprets this value as unsigned,
216            // making the max number of parameters 65535, not 32767
217            // https://github.com/launchbadge/sqlx/issues/3464
218            // https://www.postgresql.org/docs/current/limits.html
219            let num_params = u16::try_from(arguments.len()).map_err(|_| {
220                err_protocol!(
221                    "PgConnection::run(): too many arguments for query: {}",
222                    arguments.len()
223                )
224            })?;
225
226            // prepare the statement if this our first time executing it
227            // always return the statement ID here
228            let (statement, metadata_) = self
229                .get_or_prepare(query, &arguments.types, persistent, metadata_opt)
230                .await?;
231
232            metadata = metadata_;
233
234            // patch holes created during encoding
235            arguments.apply_patches(self, &metadata.parameters).await?;
236
237            // consume messages till `ReadyForQuery` before bind and execute
238            self.wait_until_ready().await?;
239
240            // bind to attach the arguments to the statement and create a portal
241            self.inner.stream.write_msg(Bind {
242                portal: PortalId::UNNAMED,
243                statement,
244                formats: &[PgValueFormat::Binary],
245                num_params,
246                params: &arguments.buffer,
247                result_formats: &[PgValueFormat::Binary],
248            })?;
249
250            // executes the portal up to the passed limit
251            // the protocol-level limit acts nearly identically to the `LIMIT` in SQL
252            self.inner.stream.write_msg(message::Execute {
253                portal: PortalId::UNNAMED,
254                // Non-zero limits cause query plan pessimization by disabling parallel workers:
255                // https://github.com/launchbadge/sqlx/issues/3673
256                limit: 0,
257            })?;
258            // From https://www.postgresql.org/docs/current/protocol-flow.html:
259            //
260            // "An unnamed portal is destroyed at the end of the transaction, or as
261            // soon as the next Bind statement specifying the unnamed portal as
262            // destination is issued. (Note that a simple Query message also
263            // destroys the unnamed portal."
264
265            // we ask the database server to close the unnamed portal and free the associated resources
266            // earlier - after the execution of the current query.
267            self.inner
268                .stream
269                .write_msg(Close::Portal(PortalId::UNNAMED))?;
270
271            // finally, [Sync] asks postgres to process the messages that we sent and respond with
272            // a [ReadyForQuery] message when it's completely done. Theoretically, we could send
273            // dozens of queries before a [Sync] and postgres can handle that. Execution on the server
274            // is still serial but it would reduce round-trips. Some kind of builder pattern that is
275            // termed batching might suit this.
276            self.write_sync();
277
278            // prepared statements are binary
279            PgValueFormat::Binary
280        } else {
281            // Query will trigger a ReadyForQuery
282            self.inner.stream.write_msg(Query(query))?;
283            self.inner.pending_ready_for_query_count += 1;
284
285            // metadata starts out as "nothing"
286            metadata = Arc::new(PgStatementMetadata::default());
287
288            // and unprepared statements are text
289            PgValueFormat::Text
290        };
291
292        self.inner.stream.flush().await?;
293
294        Ok(try_stream! {
295            loop {
296                let message = self.inner.stream.recv().await?;
297
298                match message.format {
299                    BackendMessageFormat::BindComplete
300                    | BackendMessageFormat::ParseComplete
301                    | BackendMessageFormat::ParameterDescription
302                    | BackendMessageFormat::NoData
303                    // unnamed portal has been closed
304                    | BackendMessageFormat::CloseComplete
305                    => {
306                        // harmless messages to ignore
307                    }
308
309                    // "Execute phase is always terminated by the appearance of
310                    // exactly one of these messages: CommandComplete,
311                    // EmptyQueryResponse (if the portal was created from an
312                    // empty query string), ErrorResponse, or PortalSuspended"
313                    BackendMessageFormat::CommandComplete => {
314                        // a SQL command completed normally
315                        let cc: CommandComplete = message.decode()?;
316
317                        let rows_affected = cc.rows_affected();
318                        logger.increase_rows_affected(rows_affected);
319                        r#yield!(Either::Left(PgQueryResult {
320                            rows_affected,
321                        }));
322                    }
323
324                    BackendMessageFormat::EmptyQueryResponse => {
325                        // empty query string passed to an unprepared execute
326                    }
327
328                    // Message::ErrorResponse is handled in self.stream.recv()
329
330                    // incomplete query execution has finished
331                    BackendMessageFormat::PortalSuspended => {}
332
333                    BackendMessageFormat::RowDescription => {
334                        // indicates that a *new* set of rows are about to be returned
335                        let (columns, column_names) = self
336                            .handle_row_description(Some(message.decode()?), false)
337                            .await?;
338
339                        metadata = Arc::new(PgStatementMetadata {
340                            column_names: Arc::new(column_names),
341                            columns,
342                            parameters: Vec::default(),
343                        });
344                    }
345
346                    BackendMessageFormat::DataRow => {
347                        logger.increment_rows_returned();
348
349                        // one of the set of rows returned by a SELECT, FETCH, etc query
350                        let data: DataRow = message.decode()?;
351                        let row = PgRow {
352                            data,
353                            format,
354                            metadata: Arc::clone(&metadata),
355                        };
356
357                        r#yield!(Either::Right(row));
358                    }
359
360                    BackendMessageFormat::ReadyForQuery => {
361                        // processing of the query string is complete
362                        self.handle_ready_for_query(message)?;
363                        break;
364                    }
365
366                    _ => {
367                        return Err(err_protocol!(
368                            "execute: unexpected message: {:?}",
369                            message.format
370                        ));
371                    }
372                }
373            }
374
375            Ok(())
376        })
377    }
378}
379
380impl<'c> Executor<'c> for &'c mut PgConnection {
381    type Database = Postgres;
382
383    fn fetch_many<'e, 'q, E>(
384        self,
385        mut query: E,
386    ) -> BoxStream<'e, Result<Either<PgQueryResult, PgRow>, Error>>
387    where
388        'c: 'e,
389        E: Execute<'q, Self::Database>,
390        'q: 'e,
391        E: 'q,
392    {
393        let sql = query.sql();
394        // False positive: https://github.com/rust-lang/rust-clippy/issues/12560
395        #[allow(clippy::map_clone)]
396        let metadata = query.statement().map(|s| Arc::clone(&s.metadata));
397        let arguments = query.take_arguments().map_err(Error::Encode);
398        let persistent = query.persistent();
399
400        Box::pin(try_stream! {
401            let arguments = arguments?;
402            let mut s = pin!(self.run(sql, arguments, persistent, metadata).await?);
403
404            while let Some(v) = s.try_next().await? {
405                r#yield!(v);
406            }
407
408            Ok(())
409        })
410    }
411
412    fn fetch_optional<'e, 'q, E>(self, mut query: E) -> BoxFuture<'e, Result<Option<PgRow>, Error>>
413    where
414        'c: 'e,
415        E: Execute<'q, Self::Database>,
416        'q: 'e,
417        E: 'q,
418    {
419        let sql = query.sql();
420        // False positive: https://github.com/rust-lang/rust-clippy/issues/12560
421        #[allow(clippy::map_clone)]
422        let metadata = query.statement().map(|s| Arc::clone(&s.metadata));
423        let arguments = query.take_arguments().map_err(Error::Encode);
424        let persistent = query.persistent();
425
426        Box::pin(async move {
427            let arguments = arguments?;
428            let mut s = pin!(self.run(sql, arguments, persistent, metadata).await?);
429
430            // With deferred constraints we need to check all responses as we
431            // could get a OK response (with uncommitted data), only to get an
432            // error response after (when the deferred constraint is actually
433            // checked).
434            let mut ret = None;
435            while let Some(result) = s.try_next().await? {
436                match result {
437                    Either::Right(r) if ret.is_none() => ret = Some(r),
438                    _ => {}
439                }
440            }
441            Ok(ret)
442        })
443    }
444
445    fn prepare_with<'e, 'q: 'e>(
446        self,
447        sql: &'q str,
448        parameters: &'e [PgTypeInfo],
449    ) -> BoxFuture<'e, Result<PgStatement<'q>, Error>>
450    where
451        'c: 'e,
452    {
453        Box::pin(async move {
454            self.wait_until_ready().await?;
455
456            let (_, metadata) = self.get_or_prepare(sql, parameters, true, None).await?;
457
458            Ok(PgStatement {
459                sql: Cow::Borrowed(sql),
460                metadata,
461            })
462        })
463    }
464
465    fn describe<'e, 'q: 'e>(
466        self,
467        sql: &'q str,
468    ) -> BoxFuture<'e, Result<Describe<Self::Database>, Error>>
469    where
470        'c: 'e,
471    {
472        Box::pin(async move {
473            self.wait_until_ready().await?;
474
475            let (stmt_id, metadata) = self.get_or_prepare(sql, &[], true, None).await?;
476
477            let nullable = self.get_nullable_for_columns(stmt_id, &metadata).await?;
478
479            Ok(Describe {
480                columns: metadata.columns.clone(),
481                nullable,
482                parameters: Some(Either::Left(metadata.parameters.clone())),
483            })
484        })
485    }
486}