sqlx_core_guts/postgres/connection/
executor.rs

1use crate::describe::Describe;
2use crate::error::Error;
3use crate::executor::{Execute, Executor};
4use crate::logger::QueryLogger;
5use crate::postgres::message::{
6    self, Bind, Close, CommandComplete, DataRow, MessageFormat, ParameterDescription, Parse, Query,
7    RowDescription,
8};
9use crate::postgres::statement::PgStatementMetadata;
10use crate::postgres::type_info::PgType;
11use crate::postgres::types::Oid;
12use crate::postgres::{
13    statement::PgStatement, PgArguments, PgConnection, PgQueryResult, PgRow, PgTypeInfo,
14    PgValueFormat, Postgres,
15};
16use either::Either;
17use futures_core::future::BoxFuture;
18use futures_core::stream::BoxStream;
19use futures_core::Stream;
20use futures_util::{pin_mut, TryStreamExt};
21use std::{borrow::Cow, sync::Arc};
22
23async fn prepare(
24    conn: &mut PgConnection,
25    sql: &str,
26    parameters: &[PgTypeInfo],
27    metadata: Option<Arc<PgStatementMetadata>>,
28) -> Result<(Oid, Arc<PgStatementMetadata>), Error> {
29    let id = conn.next_statement_id;
30    conn.next_statement_id.incr_one();
31
32    // build a list of type OIDs to send to the database in the PARSE command
33    // we have not yet started the query sequence, so we are *safe* to cleanly make
34    // additional queries here to get any missing OIDs
35
36    let mut param_types = Vec::with_capacity(parameters.len());
37
38    for ty in parameters {
39        param_types.push(if let PgType::DeclareWithName(name) = &ty.0 {
40            conn.fetch_type_id_by_name(name).await?
41        } else {
42            ty.0.oid()
43        });
44    }
45
46    // flush and wait until we are re-ready
47    conn.wait_until_ready().await?;
48
49    // next we send the PARSE command to the server
50    conn.stream.write(Parse {
51        param_types: &*param_types,
52        query: sql,
53        statement: id,
54    });
55
56    if metadata.is_none() {
57        // get the statement columns and parameters
58        conn.stream.write(message::Describe::Statement(id));
59    }
60
61    // we ask for the server to immediately send us the result of the PARSE command
62    conn.write_sync();
63    conn.stream.flush().await?;
64
65    // indicates that the SQL query string is now successfully parsed and has semantic validity
66    let _ = conn
67        .stream
68        .recv_expect(MessageFormat::ParseComplete)
69        .await?;
70
71    let metadata = if let Some(metadata) = metadata {
72        // each SYNC produces one READY FOR QUERY
73        conn.recv_ready_for_query().await?;
74
75        // we already have metadata
76        metadata
77    } else {
78        let parameters = recv_desc_params(conn).await?;
79
80        let rows = recv_desc_rows(conn).await?;
81
82        // each SYNC produces one READY FOR QUERY
83        conn.recv_ready_for_query().await?;
84
85        let parameters = conn.handle_parameter_description(parameters).await?;
86
87        let (columns, column_names) = conn.handle_row_description(rows, true).await?;
88
89        // ensure that if we did fetch custom data, we wait until we are fully ready before
90        // continuing
91        conn.wait_until_ready().await?;
92
93        Arc::new(PgStatementMetadata {
94            parameters,
95            columns,
96            column_names,
97        })
98    };
99
100    Ok((id, metadata))
101}
102
103async fn recv_desc_params(conn: &mut PgConnection) -> Result<ParameterDescription, Error> {
104    conn.stream
105        .recv_expect(MessageFormat::ParameterDescription)
106        .await
107}
108
109async fn recv_desc_rows(conn: &mut PgConnection) -> Result<Option<RowDescription>, Error> {
110    let rows: Option<RowDescription> = match conn.stream.recv().await? {
111        // describes the rows that will be returned when the statement is eventually executed
112        message if message.format == MessageFormat::RowDescription => Some(message.decode()?),
113
114        // no data would be returned if this statement was executed
115        message if message.format == MessageFormat::NoData => None,
116
117        message => {
118            return Err(err_protocol!(
119                "expecting RowDescription or NoData but received {:?}",
120                message.format
121            ));
122        }
123    };
124
125    Ok(rows)
126}
127
128impl PgConnection {
129    // wait for CloseComplete to indicate a statement was closed
130    pub(super) async fn wait_for_close_complete(&mut self, mut count: usize) -> Result<(), Error> {
131        // we need to wait for the [CloseComplete] to be returned from the server
132        while count > 0 {
133            match self.stream.recv().await? {
134                message if message.format == MessageFormat::PortalSuspended => {
135                    // there was an open portal
136                    // this can happen if the last time a statement was used it was not fully executed
137                    // such as in [fetch_one]
138                }
139
140                message if message.format == MessageFormat::CloseComplete => {
141                    // successfully closed the statement (and freed up the server resources)
142                    count -= 1;
143                }
144
145                message => {
146                    return Err(err_protocol!(
147                        "expecting PortalSuspended or CloseComplete but received {:?}",
148                        message.format
149                    ));
150                }
151            }
152        }
153
154        Ok(())
155    }
156
157    pub(crate) fn write_sync(&mut self) {
158        self.stream.write(message::Sync);
159
160        // all SYNC messages will return a ReadyForQuery
161        self.pending_ready_for_query_count += 1;
162    }
163
164    async fn get_or_prepare<'a>(
165        &mut self,
166        sql: &str,
167        parameters: &[PgTypeInfo],
168        // should we store the result of this prepare to the cache
169        store_to_cache: bool,
170        // optional metadata that was provided by the user, this means they are reusing
171        // a statement object
172        metadata: Option<Arc<PgStatementMetadata>>,
173    ) -> Result<(Oid, Arc<PgStatementMetadata>), Error> {
174        if let Some(statement) = self.cache_statement.get_mut(sql) {
175            return Ok((*statement).clone());
176        }
177
178        let statement = prepare(self, sql, parameters, metadata).await?;
179
180        if store_to_cache && self.cache_statement.is_enabled() {
181            if let Some((id, _)) = self.cache_statement.insert(sql, statement.clone()) {
182                self.stream.write(Close::Statement(id));
183                self.write_sync();
184
185                self.stream.flush().await?;
186
187                self.wait_for_close_complete(1).await?;
188                self.recv_ready_for_query().await?;
189            }
190        }
191
192        Ok(statement)
193    }
194
195    async fn run<'e, 'c: 'e, 'q: 'e>(
196        &'c mut self,
197        query: &'q str,
198        arguments: Option<PgArguments>,
199        limit: u8,
200        persistent: bool,
201        metadata_opt: Option<Arc<PgStatementMetadata>>,
202    ) -> Result<impl Stream<Item = Result<Either<PgQueryResult, PgRow>, Error>> + 'e, Error> {
203        let mut logger = QueryLogger::new(query, self.log_settings.clone());
204
205        // before we continue, wait until we are "ready" to accept more queries
206        self.wait_until_ready().await?;
207
208        let mut metadata: Arc<PgStatementMetadata>;
209
210        let format = if let Some(mut arguments) = arguments {
211            // prepare the statement if this our first time executing it
212            // always return the statement ID here
213            let (statement, metadata_) = self
214                .get_or_prepare(query, &arguments.types, persistent, metadata_opt)
215                .await?;
216
217            metadata = metadata_;
218
219            // patch holes created during encoding
220            arguments.apply_patches(self, &metadata.parameters).await?;
221
222            // apply patches use fetch_optional thaht may produce `PortalSuspended` message,
223            // consume messages til `ReadyForQuery` before bind and execute
224            self.wait_until_ready().await?;
225
226            // bind to attach the arguments to the statement and create a portal
227            self.stream.write(Bind {
228                portal: None,
229                statement,
230                formats: &[PgValueFormat::Binary],
231                num_params: arguments.types.len() as i16,
232                params: &*arguments.buffer,
233                result_formats: &[PgValueFormat::Binary],
234            });
235
236            // executes the portal up to the passed limit
237            // the protocol-level limit acts nearly identically to the `LIMIT` in SQL
238            self.stream.write(message::Execute {
239                portal: None,
240                limit: limit.into(),
241            });
242
243            // finally, [Sync] asks postgres to process the messages that we sent and respond with
244            // a [ReadyForQuery] message when it's completely done. Theoretically, we could send
245            // dozens of queries before a [Sync] and postgres can handle that. Execution on the server
246            // is still serial but it would reduce round-trips. Some kind of builder pattern that is
247            // termed batching might suit this.
248            self.write_sync();
249
250            // prepared statements are binary
251            PgValueFormat::Binary
252        } else {
253            // Query will trigger a ReadyForQuery
254            self.stream.write(Query(query));
255            self.pending_ready_for_query_count += 1;
256
257            // metadata starts out as "nothing"
258            metadata = Arc::new(PgStatementMetadata::default());
259
260            // and unprepared statements are text
261            PgValueFormat::Text
262        };
263
264        self.stream.flush().await?;
265
266        Ok(try_stream! {
267            loop {
268                let message = self.stream.recv().await?;
269
270                match message.format {
271                    MessageFormat::BindComplete
272                    | MessageFormat::ParseComplete
273                    | MessageFormat::ParameterDescription
274                    | MessageFormat::NoData => {
275                        // harmless messages to ignore
276                    }
277
278                    MessageFormat::CommandComplete => {
279                        // a SQL command completed normally
280                        let cc: CommandComplete = message.decode()?;
281
282                        let rows_affected = cc.rows_affected();
283                        logger.increase_rows_affected(rows_affected);
284                        r#yield!(Either::Left(PgQueryResult {
285                            rows_affected,
286                        }));
287                    }
288
289                    MessageFormat::EmptyQueryResponse => {
290                        // empty query string passed to an unprepared execute
291                    }
292
293                    MessageFormat::RowDescription => {
294                        // indicates that a *new* set of rows are about to be returned
295                        let (columns, column_names) = self
296                            .handle_row_description(Some(message.decode()?), false)
297                            .await?;
298
299                        metadata = Arc::new(PgStatementMetadata {
300                            column_names,
301                            columns,
302                            parameters: Vec::default(),
303                        });
304                    }
305
306                    MessageFormat::DataRow => {
307                        logger.increment_rows_returned();
308
309                        // one of the set of rows returned by a SELECT, FETCH, etc query
310                        let data: DataRow = message.decode()?;
311                        let row = PgRow {
312                            data,
313                            format,
314                            metadata: Arc::clone(&metadata),
315                        };
316
317                        r#yield!(Either::Right(row));
318                    }
319
320                    MessageFormat::ReadyForQuery => {
321                        // processing of the query string is complete
322                        self.handle_ready_for_query(message)?;
323                        break;
324                    }
325
326                    _ => {
327                        return Err(err_protocol!(
328                            "execute: unexpected message: {:?}",
329                            message.format
330                        ));
331                    }
332                }
333            }
334
335            Ok(())
336        })
337    }
338}
339
340impl<'c> Executor<'c> for &'c mut PgConnection {
341    type Database = Postgres;
342
343    fn fetch_many<'e, 'q: 'e, E: 'q>(
344        self,
345        mut query: E,
346    ) -> BoxStream<'e, Result<Either<PgQueryResult, PgRow>, Error>>
347    where
348        'c: 'e,
349        E: Execute<'q, Self::Database>,
350    {
351        let sql = query.sql();
352        let metadata = query.statement().map(|s| Arc::clone(&s.metadata));
353        let arguments = query.take_arguments();
354        let persistent = query.persistent();
355
356        Box::pin(try_stream! {
357            let s = self.run(sql, arguments, 0, persistent, metadata).await?;
358            pin_mut!(s);
359
360            while let Some(v) = s.try_next().await? {
361                r#yield!(v);
362            }
363
364            Ok(())
365        })
366    }
367
368    fn fetch_optional<'e, 'q: 'e, E: 'q>(
369        self,
370        mut query: E,
371    ) -> BoxFuture<'e, Result<Option<PgRow>, Error>>
372    where
373        'c: 'e,
374        E: Execute<'q, Self::Database>,
375    {
376        let sql = query.sql();
377        let metadata = query.statement().map(|s| Arc::clone(&s.metadata));
378        let arguments = query.take_arguments();
379        let persistent = query.persistent();
380
381        Box::pin(async move {
382            let s = self.run(sql, arguments, 1, persistent, metadata).await?;
383            pin_mut!(s);
384
385            while let Some(s) = s.try_next().await? {
386                if let Either::Right(r) = s {
387                    return Ok(Some(r));
388                }
389            }
390
391            Ok(None)
392        })
393    }
394
395    fn prepare_with<'e, 'q: 'e>(
396        self,
397        sql: &'q str,
398        parameters: &'e [PgTypeInfo],
399    ) -> BoxFuture<'e, Result<PgStatement<'q>, Error>>
400    where
401        'c: 'e,
402    {
403        Box::pin(async move {
404            self.wait_until_ready().await?;
405
406            let (_, metadata) = self.get_or_prepare(sql, parameters, true, None).await?;
407
408            Ok(PgStatement {
409                sql: Cow::Borrowed(sql),
410                metadata,
411            })
412        })
413    }
414
415    fn describe<'e, 'q: 'e>(
416        self,
417        sql: &'q str,
418    ) -> BoxFuture<'e, Result<Describe<Self::Database>, Error>>
419    where
420        'c: 'e,
421    {
422        Box::pin(async move {
423            self.wait_until_ready().await?;
424
425            let (stmt_id, metadata) = self.get_or_prepare(sql, &[], true, None).await?;
426
427            let nullable = self.get_nullable_for_columns(stmt_id, &metadata).await?;
428
429            Ok(Describe {
430                columns: metadata.columns.clone(),
431                nullable,
432                parameters: Some(Either::Left(metadata.parameters.clone())),
433            })
434        })
435    }
436}