Skip to main content

sqlx_postgres/connection/
executor.rs

1use crate::error::Error;
2use crate::executor::{Execute, Executor};
3use crate::io::{PortalId, StatementId};
4use crate::logger::QueryLogger;
5use crate::message::{
6    self, BackendMessageFormat, Bind, Close, CommandComplete, DataRow, ParameterDescription, Parse,
7    ParseComplete, RowDescription,
8};
9use crate::statement::PgStatementMetadata;
10use crate::{
11    statement::PgStatement, PgArguments, PgConnection, PgQueryResult, PgRow, PgTypeInfo,
12    PgValueFormat, Postgres,
13};
14use futures_core::future::BoxFuture;
15use futures_core::stream::BoxStream;
16use futures_core::Stream;
17use futures_util::TryStreamExt;
18use sqlx_core::arguments::Arguments;
19use sqlx_core::sql_str::SqlStr;
20use sqlx_core::Either;
21use std::{pin::pin, sync::Arc};
22
23async fn prepare(
24    conn: &mut PgConnection,
25    sql: &str,
26    arg_types: &[PgTypeInfo],
27    metadata: Option<Arc<PgStatementMetadata>>,
28    persistent: bool,
29    resolve_column_origin: bool,
30) -> Result<(StatementId, Arc<PgStatementMetadata>), Error> {
31    let id = if persistent {
32        let id = conn.inner.next_statement_id;
33        conn.inner.next_statement_id = id.next();
34        id
35    } else {
36        StatementId::UNNAMED
37    };
38
39    // build a list of type OIDs to send to the database in the PARSE command
40    // we have not yet started the query sequence, so we are *safe* to cleanly make
41    // additional queries here to get any missing OIDs
42    let param_types = conn.resolve_types(arg_types).await?;
43
44    // flush and wait until we are re-ready
45    conn.wait_until_ready().await?;
46
47    // next we send the PARSE command to the server
48    conn.inner.stream.write_msg(Parse {
49        param_types: &param_types,
50        query: sql,
51        statement: id,
52    })?;
53
54    if metadata.is_none() {
55        // get the statement columns and parameters
56        conn.inner
57            .stream
58            .write_msg(message::Describe::Statement(id))?;
59    }
60
61    // we ask for the server to immediately send us the result of the PARSE command
62    conn.write_sync();
63    conn.inner.stream.flush().await?;
64
65    // indicates that the SQL query string is now successfully parsed and has semantic validity
66    conn.inner.stream.recv_expect::<ParseComplete>().await?;
67
68    let metadata = if let Some(metadata) = metadata {
69        // each SYNC produces one READY FOR QUERY
70        conn.recv_ready_for_query().await?;
71
72        // we already have metadata
73        metadata
74    } else {
75        let parameters = recv_desc_params(conn).await?;
76
77        let row_desc = recv_desc_rows(conn).await?;
78
79        // each SYNC produces one READY FOR QUERY
80        conn.recv_ready_for_query().await?;
81
82        let metadata = conn
83            .resolve_statement_metadata::<true>(Some(parameters), row_desc, resolve_column_origin)
84            .await?;
85
86        // ensure that if we did fetch custom data, we wait until we are fully ready before
87        // continuing
88        conn.wait_until_ready().await?;
89
90        metadata
91    };
92
93    Ok((id, metadata))
94}
95
96async fn recv_desc_params(conn: &mut PgConnection) -> Result<ParameterDescription, Error> {
97    conn.inner.stream.recv_expect().await
98}
99
100async fn recv_desc_rows(conn: &mut PgConnection) -> Result<Option<RowDescription>, Error> {
101    let rows: Option<RowDescription> = match conn.inner.stream.recv().await? {
102        // describes the rows that will be returned when the statement is eventually executed
103        message if message.format == BackendMessageFormat::RowDescription => {
104            Some(message.decode()?)
105        }
106
107        // no data would be returned if this statement was executed
108        message if message.format == BackendMessageFormat::NoData => None,
109
110        message => {
111            return Err(err_protocol!(
112                "expecting RowDescription or NoData but received {:?}",
113                message.format
114            ));
115        }
116    };
117
118    Ok(rows)
119}
120
121impl PgConnection {
122    // wait for CloseComplete to indicate a statement was closed
123    pub(super) async fn wait_for_close_complete(&mut self, mut count: usize) -> Result<(), Error> {
124        // we need to wait for the [CloseComplete] to be returned from the server
125        while count > 0 {
126            match self.inner.stream.recv().await? {
127                message if message.format == BackendMessageFormat::PortalSuspended => {
128                    // there was an open portal
129                    // this can happen if the last time a statement was used it was not fully executed
130                }
131
132                message if message.format == BackendMessageFormat::CloseComplete => {
133                    // successfully closed the statement (and freed up the server resources)
134                    count -= 1;
135                }
136
137                message => {
138                    return Err(err_protocol!(
139                        "expecting PortalSuspended or CloseComplete but received {:?}",
140                        message.format
141                    ));
142                }
143            }
144        }
145
146        Ok(())
147    }
148
149    #[inline(always)]
150    pub(crate) fn write_sync(&mut self) {
151        self.inner
152            .stream
153            .write_msg(message::Sync)
154            .expect("BUG: Sync should not be too big for protocol");
155
156        // all SYNC messages will return a ReadyForQuery
157        self.inner.pending_ready_for_query_count += 1;
158    }
159
160    async fn get_or_prepare(
161        &mut self,
162        sql: &str,
163        parameters: &[PgTypeInfo],
164        persistent: bool,
165        // optional metadata that was provided by the user, this means they are reusing
166        // a statement object
167        metadata: Option<Arc<PgStatementMetadata>>,
168        resolve_column_origin: bool,
169    ) -> Result<(StatementId, Arc<PgStatementMetadata>), Error> {
170        if let Some(statement) = self.inner.cache_statement.get_mut(sql) {
171            return Ok((*statement).clone());
172        }
173
174        let statement = prepare(
175            self,
176            sql,
177            parameters,
178            metadata,
179            persistent,
180            resolve_column_origin,
181        )
182        .await?;
183
184        if persistent && self.inner.cache_statement.is_enabled() {
185            if let Some((id, _)) = self.inner.cache_statement.insert(sql, statement.clone()) {
186                self.inner.stream.write_msg(Close::Statement(id))?;
187                self.write_sync();
188
189                self.inner.stream.flush().await?;
190
191                self.wait_for_close_complete(1).await?;
192                self.recv_ready_for_query().await?;
193            }
194        }
195
196        Ok(statement)
197    }
198
199    pub(crate) async fn run<'e, 'c: 'e, 'q: 'e>(
200        &'c mut self,
201        query: SqlStr,
202        arguments: Option<PgArguments>,
203        persistent: bool,
204        metadata_opt: Option<Arc<PgStatementMetadata>>,
205    ) -> Result<impl Stream<Item = Result<Either<PgQueryResult, PgRow>, Error>> + 'e, Error> {
206        let mut logger = QueryLogger::new(query, self.inner.log_settings.clone());
207        let sql = logger.sql().as_str();
208
209        // before we continue, wait until we are "ready" to accept more queries
210        self.wait_until_ready().await?;
211
212        let mut metadata: Arc<PgStatementMetadata>;
213
214        let format = if let Some(mut arguments) = arguments {
215            // Check this before we write anything to the stream.
216            //
217            // Note: Postgres actually interprets this value as unsigned,
218            // making the max number of parameters 65535, not 32767
219            // https://github.com/launchbadge/sqlx/issues/3464
220            // https://www.postgresql.org/docs/current/limits.html
221            let num_params = u16::try_from(arguments.len()).map_err(|_| {
222                err_protocol!(
223                    "PgConnection::run(): too many arguments for query: {}",
224                    arguments.len()
225                )
226            })?;
227
228            // prepare the statement if this our first time executing it
229            // always return the statement ID here
230            let (statement, metadata_) = self
231                .get_or_prepare(sql, &arguments.types, persistent, metadata_opt, false)
232                .await?;
233
234            metadata = metadata_;
235
236            // patch holes created during encoding
237            arguments.apply_patches(self, &metadata.parameters).await?;
238
239            // consume messages till `ReadyForQuery` before bind and execute
240            self.wait_until_ready().await?;
241
242            // bind to attach the arguments to the statement and create a portal
243            self.inner.stream.write_msg(Bind {
244                portal: PortalId::UNNAMED,
245                statement,
246                formats: &[PgValueFormat::Binary],
247                num_params,
248                params: &arguments.buffer,
249                result_formats: &[PgValueFormat::Binary],
250            })?;
251
252            // executes the portal up to the passed limit
253            // the protocol-level limit acts nearly identically to the `LIMIT` in SQL
254            self.inner.stream.write_msg(message::Execute {
255                portal: PortalId::UNNAMED,
256                // Non-zero limits cause query plan pessimization by disabling parallel workers:
257                // https://github.com/launchbadge/sqlx/issues/3673
258                limit: 0,
259            })?;
260            // From https://www.postgresql.org/docs/current/protocol-flow.html:
261            //
262            // "An unnamed portal is destroyed at the end of the transaction, or as
263            // soon as the next Bind statement specifying the unnamed portal as
264            // destination is issued. (Note that a simple Query message also
265            // destroys the unnamed portal."
266
267            // we ask the database server to close the unnamed portal and free the associated resources
268            // earlier - after the execution of the current query.
269            self.inner
270                .stream
271                .write_msg(Close::Portal(PortalId::UNNAMED))?;
272
273            // finally, [Sync] asks postgres to process the messages that we sent and respond with
274            // a [ReadyForQuery] message when it's completely done. Theoretically, we could send
275            // dozens of queries before a [Sync] and postgres can handle that. Execution on the server
276            // is still serial but it would reduce round-trips. Some kind of builder pattern that is
277            // termed batching might suit this.
278            self.write_sync();
279
280            // prepared statements are binary
281            PgValueFormat::Binary
282        } else {
283            // Query will trigger a ReadyForQuery
284            self.queue_simple_query(sql)?;
285
286            // metadata starts out as "nothing"
287            metadata = Arc::new(PgStatementMetadata::default());
288
289            // and unprepared statements are text
290            PgValueFormat::Text
291        };
292
293        self.inner.stream.flush().await?;
294
295        Ok(try_stream! {
296            loop {
297                let message = self.inner.stream.recv().await?;
298
299                match message.format {
300                    BackendMessageFormat::BindComplete
301                    | BackendMessageFormat::ParseComplete
302                    | BackendMessageFormat::ParameterDescription
303                    | BackendMessageFormat::NoData
304                    // unnamed portal has been closed
305                    | BackendMessageFormat::CloseComplete
306                    => {
307                        // harmless messages to ignore
308                    }
309
310                    // "Execute phase is always terminated by the appearance of
311                    // exactly one of these messages: CommandComplete,
312                    // EmptyQueryResponse (if the portal was created from an
313                    // empty query string), ErrorResponse, or PortalSuspended"
314                    BackendMessageFormat::CommandComplete => {
315                        // a SQL command completed normally
316                        let cc: CommandComplete = message.decode()?;
317
318                        let rows_affected = cc.rows_affected();
319                        logger.increase_rows_affected(rows_affected);
320                        r#yield!(Either::Left(PgQueryResult {
321                            rows_affected,
322                        }));
323                    }
324
325                    BackendMessageFormat::EmptyQueryResponse => {
326                        // empty query string passed to an unprepared execute
327                    }
328
329                    // Message::ErrorResponse is handled in self.stream.recv()
330
331                    // incomplete query execution has finished
332                    BackendMessageFormat::PortalSuspended => {}
333
334                    // indicates that a *new* set of rows are about to be returned
335                    BackendMessageFormat::RowDescription => {
336                        let new_metadata = self.resolve_statement_metadata::<false>(
337                            None,
338                            Some(message.decode()?),
339                            false,
340                        ).await?;
341
342                        metadata = new_metadata;
343                    }
344
345                    BackendMessageFormat::DataRow => {
346                        logger.increment_rows_returned();
347
348                        // one of the set of rows returned by a SELECT, FETCH, etc query
349                        let data: DataRow = message.decode()?;
350                        let row = PgRow {
351                            data,
352                            format,
353                            metadata: Arc::clone(&metadata),
354                        };
355
356                        r#yield!(Either::Right(row));
357                    }
358
359                    BackendMessageFormat::ReadyForQuery => {
360                        // processing of the query string is complete
361                        self.handle_ready_for_query(message)?;
362                        break;
363                    }
364
365                    _ => {
366                        return Err(err_protocol!(
367                            "execute: unexpected message: {:?}",
368                            message.format
369                        ));
370                    }
371                }
372            }
373
374            Ok(())
375        })
376    }
377}
378
379impl<'c> Executor<'c> for &'c mut PgConnection {
380    type Database = Postgres;
381
382    fn fetch_many<'e, 'q, E>(
383        self,
384        mut query: E,
385    ) -> BoxStream<'e, Result<Either<PgQueryResult, PgRow>, Error>>
386    where
387        'c: 'e,
388        E: Execute<'q, Self::Database>,
389        'q: 'e,
390        E: 'q,
391    {
392        // False positive: https://github.com/rust-lang/rust-clippy/issues/12560
393        #[allow(clippy::map_clone)]
394        let metadata = query.statement().map(|s| Arc::clone(&s.metadata));
395        let arguments = query.take_arguments().map_err(Error::Encode);
396        let persistent = query.persistent();
397        let sql = query.sql();
398
399        Box::pin(try_stream! {
400            let arguments = arguments?;
401            let mut s = pin!(self.run(sql, arguments, persistent, metadata).await?);
402
403            while let Some(v) = s.try_next().await? {
404                r#yield!(v);
405            }
406
407            Ok(())
408        })
409    }
410
411    fn fetch_optional<'e, 'q, E>(self, mut query: E) -> BoxFuture<'e, Result<Option<PgRow>, Error>>
412    where
413        'c: 'e,
414        E: Execute<'q, Self::Database>,
415        'q: 'e,
416        E: 'q,
417    {
418        // False positive: https://github.com/rust-lang/rust-clippy/issues/12560
419        #[allow(clippy::map_clone)]
420        let metadata = query.statement().map(|s| Arc::clone(&s.metadata));
421        let arguments = query.take_arguments().map_err(Error::Encode);
422        let persistent = query.persistent();
423
424        Box::pin(async move {
425            let sql = query.sql();
426            let arguments = arguments?;
427            let mut s = pin!(self.run(sql, arguments, persistent, metadata).await?);
428
429            // With deferred constraints we need to check all responses as we
430            // could get a OK response (with uncommitted data), only to get an
431            // error response after (when the deferred constraint is actually
432            // checked).
433            let mut ret = None;
434            while let Some(result) = s.try_next().await? {
435                match result {
436                    Either::Right(r) if ret.is_none() => ret = Some(r),
437                    _ => {}
438                }
439            }
440            Ok(ret)
441        })
442    }
443
444    fn prepare_with<'e>(
445        self,
446        sql: SqlStr,
447        parameters: &'e [PgTypeInfo],
448    ) -> BoxFuture<'e, Result<PgStatement, Error>>
449    where
450        'c: 'e,
451    {
452        Box::pin(async move {
453            self.wait_until_ready().await?;
454
455            let (_, metadata) = self
456                .get_or_prepare(sql.as_str(), parameters, true, None, true)
457                .await?;
458
459            Ok(PgStatement { sql, metadata })
460        })
461    }
462
463    #[cfg(feature = "offline")]
464    fn describe<'e>(
465        self,
466        sql: SqlStr,
467    ) -> BoxFuture<'e, Result<crate::describe::Describe<Self::Database>, Error>>
468    where
469        'c: 'e,
470    {
471        Box::pin(async move {
472            self.wait_until_ready().await?;
473
474            let (stmt_id, metadata) = self
475                .get_or_prepare(sql.as_str(), &[], true, None, true)
476                .await?;
477
478            let nullable = self.get_nullable_for_columns(stmt_id, &metadata).await?;
479
480            Ok(crate::describe::Describe {
481                columns: metadata.columns.clone(),
482                nullable,
483                parameters: Some(Either::Left(metadata.parameters.clone())),
484            })
485        })
486    }
487}