sqlx_core_oldapi/postgres/connection/
executor.rs

1use crate::describe::Describe;
2use crate::error::Error;
3use crate::executor::{Execute, Executor};
4use crate::logger::QueryLogger;
5use crate::postgres::message::{
6    self, Bind, Close, CommandComplete, DataRow, MessageFormat, ParameterDescription, Parse, Query,
7    RowDescription,
8};
9use crate::postgres::statement::PgStatementMetadata;
10use crate::postgres::type_info::PgType;
11use crate::postgres::types::Oid;
12use crate::postgres::{
13    statement::PgStatement, PgArguments, PgConnection, PgQueryResult, PgRow, PgTypeInfo,
14    PgValueFormat, Postgres,
15};
16use either::Either;
17use futures_core::future::BoxFuture;
18use futures_core::stream::BoxStream;
19use futures_core::Stream;
20use futures_util::{pin_mut, TryStreamExt};
21use std::{borrow::Cow, sync::Arc};
22
23async fn prepare(
24    conn: &mut PgConnection,
25    sql: &str,
26    parameters: &[PgTypeInfo],
27    metadata: Option<Arc<PgStatementMetadata>>,
28) -> Result<(Oid, Arc<PgStatementMetadata>), Error> {
29    let id = conn.next_statement_id;
30    conn.next_statement_id.incr_one();
31
32    // build a list of type OIDs to send to the database in the PARSE command
33    // we have not yet started the query sequence, so we are *safe* to cleanly make
34    // additional queries here to get any missing OIDs
35
36    let mut param_types = Vec::with_capacity(parameters.len());
37
38    for ty in parameters {
39        param_types.push(if let PgType::DeclareWithName(name) = &ty.0 {
40            conn.fetch_type_id_by_name(name).await?
41        } else {
42            ty.0.oid()
43        });
44    }
45
46    // flush and wait until we are re-ready
47    conn.wait_until_ready().await?;
48
49    // next we send the PARSE command to the server
50    conn.stream.write(Parse {
51        param_types: &*param_types,
52        query: sql,
53        statement: id,
54    });
55
56    if metadata.is_none() {
57        // get the statement columns and parameters
58        conn.stream.write(message::Describe::Statement(id));
59    }
60
61    // we ask for the server to immediately send us the result of the PARSE command
62    conn.write_sync();
63    conn.stream.flush().await?;
64
65    // indicates that the SQL query string is now successfully parsed and has semantic validity
66    let _: () = conn
67        .stream
68        .recv_expect(MessageFormat::ParseComplete)
69        .await?;
70
71    let metadata = if let Some(metadata) = metadata {
72        // each SYNC produces one READY FOR QUERY
73        conn.recv_ready_for_query().await?;
74
75        // we already have metadata
76        metadata
77    } else {
78        let parameters = recv_desc_params(conn).await?;
79
80        let rows = recv_desc_rows(conn).await?;
81
82        // each SYNC produces one READY FOR QUERY
83        conn.recv_ready_for_query().await?;
84
85        let parameters = conn.handle_parameter_description(parameters).await?;
86
87        let (columns, column_names) = conn.handle_row_description(rows, true).await?;
88
89        // ensure that if we did fetch custom data, we wait until we are fully ready before
90        // continuing
91        conn.wait_until_ready().await?;
92
93        Arc::new(PgStatementMetadata {
94            parameters,
95            columns,
96            column_names,
97        })
98    };
99
100    Ok((id, metadata))
101}
102
103async fn recv_desc_params(conn: &mut PgConnection) -> Result<ParameterDescription, Error> {
104    conn.stream
105        .recv_expect(MessageFormat::ParameterDescription)
106        .await
107}
108
109async fn recv_desc_rows(conn: &mut PgConnection) -> Result<Option<RowDescription>, Error> {
110    let rows: Option<RowDescription> = match conn.stream.recv().await? {
111        // describes the rows that will be returned when the statement is eventually executed
112        message if message.format == MessageFormat::RowDescription => Some(message.decode()?),
113
114        // no data would be returned if this statement was executed
115        message if message.format == MessageFormat::NoData => None,
116
117        message => {
118            return Err(err_protocol!(
119                "expecting RowDescription or NoData but received {:?}",
120                message.format
121            ));
122        }
123    };
124
125    Ok(rows)
126}
127
128impl PgConnection {
129    // wait for CloseComplete to indicate a statement was closed
130    pub(super) async fn wait_for_close_complete(&mut self, mut count: usize) -> Result<(), Error> {
131        // we need to wait for the [CloseComplete] to be returned from the server
132        while count > 0 {
133            match self.stream.recv().await? {
134                message if message.format == MessageFormat::PortalSuspended => {
135                    // there was an open portal
136                    // this can happen if the last time a statement was used it was not fully executed
137                }
138
139                message if message.format == MessageFormat::CloseComplete => {
140                    // successfully closed the statement (and freed up the server resources)
141                    count -= 1;
142                }
143
144                message => {
145                    return Err(err_protocol!(
146                        "expecting PortalSuspended or CloseComplete but received {:?}",
147                        message.format
148                    ));
149                }
150            }
151        }
152
153        Ok(())
154    }
155
156    pub(crate) fn write_sync(&mut self) {
157        self.stream.write(message::Sync);
158
159        // all SYNC messages will return a ReadyForQuery
160        self.pending_ready_for_query_count += 1;
161    }
162
163    async fn get_or_prepare<'a>(
164        &mut self,
165        sql: &str,
166        parameters: &[PgTypeInfo],
167        // should we store the result of this prepare to the cache
168        store_to_cache: bool,
169        // optional metadata that was provided by the user, this means they are reusing
170        // a statement object
171        metadata: Option<Arc<PgStatementMetadata>>,
172    ) -> Result<(Oid, Arc<PgStatementMetadata>), Error> {
173        if let Some(statement) = self.cache_statement.get_mut(sql) {
174            return Ok((*statement).clone());
175        }
176
177        let statement = prepare(self, sql, parameters, metadata).await?;
178
179        if store_to_cache && self.cache_statement.is_enabled() {
180            if let Some((id, _)) = self.cache_statement.insert(sql, statement.clone()) {
181                self.stream.write(Close::Statement(id));
182                self.write_sync();
183
184                self.stream.flush().await?;
185
186                self.wait_for_close_complete(1).await?;
187                self.recv_ready_for_query().await?;
188            }
189        }
190
191        Ok(statement)
192    }
193
194    async fn run<'e, 'c: 'e, 'q: 'e>(
195        &'c mut self,
196        query: &'q str,
197        arguments: Option<PgArguments>,
198        limit: u8,
199        persistent: bool,
200        metadata_opt: Option<Arc<PgStatementMetadata>>,
201    ) -> Result<impl Stream<Item = Result<Either<PgQueryResult, PgRow>, Error>> + 'e, Error> {
202        let mut logger = QueryLogger::new(query, self.log_settings.clone());
203
204        // before we continue, wait until we are "ready" to accept more queries
205        self.wait_until_ready().await?;
206
207        let mut metadata: Arc<PgStatementMetadata>;
208
209        let format = if let Some(mut arguments) = arguments {
210            // prepare the statement if this our first time executing it
211            // always return the statement ID here
212            let (statement, metadata_) = self
213                .get_or_prepare(query, &arguments.types, persistent, metadata_opt)
214                .await?;
215
216            metadata = metadata_;
217
218            // patch holes created during encoding
219            arguments.apply_patches(self, &metadata.parameters).await?;
220
221            // consume messages till `ReadyForQuery` before bind and execute
222            self.wait_until_ready().await?;
223
224            // bind to attach the arguments to the statement and create a portal
225            self.stream.write(Bind {
226                portal: None,
227                statement,
228                formats: &[PgValueFormat::Binary],
229                num_params: arguments.types.len() as i16,
230                params: &*arguments.buffer,
231                result_formats: &[PgValueFormat::Binary],
232            });
233
234            // executes the portal up to the passed limit
235            // the protocol-level limit acts nearly identically to the `LIMIT` in SQL
236            self.stream.write(message::Execute {
237                portal: None,
238                limit: limit.into(),
239            });
240            // From https://www.postgresql.org/docs/current/protocol-flow.html:
241            //
242            // "An unnamed portal is destroyed at the end of the transaction, or as
243            // soon as the next Bind statement specifying the unnamed portal as
244            // destination is issued. (Note that a simple Query message also
245            // destroys the unnamed portal."
246
247            // we ask the database server to close the unnamed portal and free the associated resources
248            // earlier - after the execution of the current query.
249            self.stream.write(message::Close::Portal(None));
250
251            // finally, [Sync] asks postgres to process the messages that we sent and respond with
252            // a [ReadyForQuery] message when it's completely done. Theoretically, we could send
253            // dozens of queries before a [Sync] and postgres can handle that. Execution on the server
254            // is still serial but it would reduce round-trips. Some kind of builder pattern that is
255            // termed batching might suit this.
256            self.write_sync();
257
258            // prepared statements are binary
259            PgValueFormat::Binary
260        } else {
261            // Query will trigger a ReadyForQuery
262            self.stream.write(Query(query));
263            self.pending_ready_for_query_count += 1;
264
265            // metadata starts out as "nothing"
266            metadata = Arc::new(PgStatementMetadata::default());
267
268            // and unprepared statements are text
269            PgValueFormat::Text
270        };
271
272        self.stream.flush().await?;
273
274        Ok(try_stream! {
275            loop {
276                let message = self.stream.recv().await?;
277
278                match message.format {
279                    MessageFormat::BindComplete
280                    | MessageFormat::ParseComplete
281                    | MessageFormat::ParameterDescription
282                    | MessageFormat::NoData
283                    // unnamed portal has been closed
284                    | MessageFormat::CloseComplete
285                    => {
286                        // harmless messages to ignore
287                    }
288
289                    // "Execute phase is always terminated by the appearance of
290                    // exactly one of these messages: CommandComplete,
291                    // EmptyQueryResponse (if the portal was created from an
292                    // empty query string), ErrorResponse, or PortalSuspended"
293                    MessageFormat::CommandComplete => {
294                        // a SQL command completed normally
295                        let cc: CommandComplete = message.decode()?;
296
297                        let rows_affected = cc.rows_affected();
298                        logger.increase_rows_affected(rows_affected);
299                        r#yield!(Either::Left(PgQueryResult {
300                            rows_affected,
301                        }));
302                    }
303
304                    MessageFormat::EmptyQueryResponse => {
305                        // empty query string passed to an unprepared execute
306                    }
307
308                    // Message::ErrorResponse is handled in self.stream.recv()
309
310                    // incomplete query execution has finished
311                    MessageFormat::PortalSuspended => {}
312
313                    MessageFormat::RowDescription => {
314                        // indicates that a *new* set of rows are about to be returned
315                        let (columns, column_names) = self
316                            .handle_row_description(Some(message.decode()?), false)
317                            .await?;
318
319                        metadata = Arc::new(PgStatementMetadata {
320                            column_names,
321                            columns,
322                            parameters: Vec::default(),
323                        });
324                    }
325
326                    MessageFormat::DataRow => {
327                        logger.increment_rows_returned();
328
329                        // one of the set of rows returned by a SELECT, FETCH, etc query
330                        let data: DataRow = message.decode()?;
331                        let row = PgRow {
332                            data,
333                            format,
334                            metadata: Arc::clone(&metadata),
335                        };
336
337                        r#yield!(Either::Right(row));
338                    }
339
340                    MessageFormat::ReadyForQuery => {
341                        // processing of the query string is complete
342                        self.handle_ready_for_query(message)?;
343                        break;
344                    }
345
346                    _ => {
347                        return Err(err_protocol!(
348                            "execute: unexpected message: {:?}",
349                            message.format
350                        ));
351                    }
352                }
353            }
354
355            Ok(())
356        })
357    }
358}
359
360impl<'c> Executor<'c> for &'c mut PgConnection {
361    type Database = Postgres;
362
363    fn fetch_many<'e, 'q: 'e, E: 'q>(
364        self,
365        mut query: E,
366    ) -> BoxStream<'e, Result<Either<PgQueryResult, PgRow>, Error>>
367    where
368        'c: 'e,
369        E: Execute<'q, Self::Database>,
370    {
371        let sql = query.sql();
372        let metadata = query.statement().map(|s| Arc::clone(&s.metadata));
373        let arguments = query.take_arguments();
374        let persistent = query.persistent();
375
376        Box::pin(try_stream! {
377            let s = self.run(sql, arguments, 0, persistent, metadata).await?;
378            pin_mut!(s);
379
380            while let Some(v) = s.try_next().await? {
381                r#yield!(v);
382            }
383
384            Ok(())
385        })
386    }
387
388    fn fetch_optional<'e, 'q: 'e, E: 'q>(
389        self,
390        mut query: E,
391    ) -> BoxFuture<'e, Result<Option<PgRow>, Error>>
392    where
393        'c: 'e,
394        E: Execute<'q, Self::Database>,
395    {
396        let sql = query.sql();
397        let metadata = query.statement().map(|s| Arc::clone(&s.metadata));
398        let arguments = query.take_arguments();
399        let persistent = query.persistent();
400
401        Box::pin(async move {
402            let s = self.run(sql, arguments, 1, persistent, metadata).await?;
403            pin_mut!(s);
404
405            while let Some(s) = s.try_next().await? {
406                if let Either::Right(r) = s {
407                    return Ok(Some(r));
408                }
409            }
410
411            Ok(None)
412        })
413    }
414
415    fn prepare_with<'e, 'q: 'e>(
416        self,
417        sql: &'q str,
418        parameters: &'e [PgTypeInfo],
419    ) -> BoxFuture<'e, Result<PgStatement<'q>, Error>>
420    where
421        'c: 'e,
422    {
423        Box::pin(async move {
424            self.wait_until_ready().await?;
425
426            let (_, metadata) = self.get_or_prepare(sql, parameters, true, None).await?;
427
428            Ok(PgStatement {
429                sql: Cow::Borrowed(sql),
430                metadata,
431            })
432        })
433    }
434
435    fn describe<'e, 'q: 'e>(
436        self,
437        sql: &'q str,
438    ) -> BoxFuture<'e, Result<Describe<Self::Database>, Error>>
439    where
440        'c: 'e,
441    {
442        Box::pin(async move {
443            self.wait_until_ready().await?;
444
445            let (stmt_id, metadata) = self.get_or_prepare(sql, &[], true, None).await?;
446
447            let nullable = self.get_nullable_for_columns(stmt_id, &metadata).await?;
448
449            Ok(Describe {
450                columns: metadata.columns.clone(),
451                nullable,
452                parameters: Some(Either::Left(metadata.parameters.clone())),
453            })
454        })
455    }
456}