sqlx_postgres/connection/
executor.rs1use crate::describe::Describe;
2use crate::error::Error;
3use crate::executor::{Execute, Executor};
4use crate::io::{PortalId, StatementId};
5use crate::logger::QueryLogger;
6use crate::message::{
7    self, BackendMessageFormat, Bind, Close, CommandComplete, DataRow, ParameterDescription, Parse,
8    ParseComplete, Query, RowDescription,
9};
10use crate::statement::PgStatementMetadata;
11use crate::{
12    statement::PgStatement, PgArguments, PgConnection, PgQueryResult, PgRow, PgTypeInfo,
13    PgValueFormat, Postgres,
14};
15use futures_core::future::BoxFuture;
16use futures_core::stream::BoxStream;
17use futures_core::Stream;
18use futures_util::TryStreamExt;
19use sqlx_core::arguments::Arguments;
20use sqlx_core::sql_str::SqlStr;
21use sqlx_core::Either;
22use std::{pin::pin, sync::Arc};
23
24async fn prepare(
25    conn: &mut PgConnection,
26    sql: &str,
27    parameters: &[PgTypeInfo],
28    metadata: Option<Arc<PgStatementMetadata>>,
29    persistent: bool,
30    fetch_column_origin: bool,
31) -> Result<(StatementId, Arc<PgStatementMetadata>), Error> {
32    let id = if persistent {
33        let id = conn.inner.next_statement_id;
34        conn.inner.next_statement_id = id.next();
35        id
36    } else {
37        StatementId::UNNAMED
38    };
39
40    let mut param_types = Vec::with_capacity(parameters.len());
45
46    for ty in parameters {
47        param_types.push(conn.resolve_type_id(&ty.0).await?);
48    }
49
50    conn.wait_until_ready().await?;
52
53    conn.inner.stream.write_msg(Parse {
55        param_types: ¶m_types,
56        query: sql,
57        statement: id,
58    })?;
59
60    if metadata.is_none() {
61        conn.inner
63            .stream
64            .write_msg(message::Describe::Statement(id))?;
65    }
66
67    conn.write_sync();
69    conn.inner.stream.flush().await?;
70
71    conn.inner.stream.recv_expect::<ParseComplete>().await?;
73
74    let metadata = if let Some(metadata) = metadata {
75        conn.recv_ready_for_query().await?;
77
78        metadata
80    } else {
81        let parameters = recv_desc_params(conn).await?;
82
83        let rows = recv_desc_rows(conn).await?;
84
85        conn.recv_ready_for_query().await?;
87
88        let parameters = conn.handle_parameter_description(parameters).await?;
89
90        let (columns, column_names) = conn
91            .handle_row_description(rows, true, fetch_column_origin)
92            .await?;
93
94        conn.wait_until_ready().await?;
97
98        Arc::new(PgStatementMetadata {
99            parameters,
100            columns,
101            column_names: Arc::new(column_names),
102        })
103    };
104
105    Ok((id, metadata))
106}
107
108async fn recv_desc_params(conn: &mut PgConnection) -> Result<ParameterDescription, Error> {
109    conn.inner.stream.recv_expect().await
110}
111
112async fn recv_desc_rows(conn: &mut PgConnection) -> Result<Option<RowDescription>, Error> {
113    let rows: Option<RowDescription> = match conn.inner.stream.recv().await? {
114        message if message.format == BackendMessageFormat::RowDescription => {
116            Some(message.decode()?)
117        }
118
119        message if message.format == BackendMessageFormat::NoData => None,
121
122        message => {
123            return Err(err_protocol!(
124                "expecting RowDescription or NoData but received {:?}",
125                message.format
126            ));
127        }
128    };
129
130    Ok(rows)
131}
132
133impl PgConnection {
134    pub(super) async fn wait_for_close_complete(&mut self, mut count: usize) -> Result<(), Error> {
136        while count > 0 {
138            match self.inner.stream.recv().await? {
139                message if message.format == BackendMessageFormat::PortalSuspended => {
140                    }
143
144                message if message.format == BackendMessageFormat::CloseComplete => {
145                    count -= 1;
147                }
148
149                message => {
150                    return Err(err_protocol!(
151                        "expecting PortalSuspended or CloseComplete but received {:?}",
152                        message.format
153                    ));
154                }
155            }
156        }
157
158        Ok(())
159    }
160
161    #[inline(always)]
162    pub(crate) fn write_sync(&mut self) {
163        self.inner
164            .stream
165            .write_msg(message::Sync)
166            .expect("BUG: Sync should not be too big for protocol");
167
168        self.inner.pending_ready_for_query_count += 1;
170    }
171
172    async fn get_or_prepare(
173        &mut self,
174        sql: &str,
175        parameters: &[PgTypeInfo],
176        persistent: bool,
177        metadata: Option<Arc<PgStatementMetadata>>,
180        fetch_column_origin: bool,
181    ) -> Result<(StatementId, Arc<PgStatementMetadata>), Error> {
182        if let Some(statement) = self.inner.cache_statement.get_mut(sql) {
183            return Ok((*statement).clone());
184        }
185
186        let statement = prepare(
187            self,
188            sql,
189            parameters,
190            metadata,
191            persistent,
192            fetch_column_origin,
193        )
194        .await?;
195
196        if persistent && self.inner.cache_statement.is_enabled() {
197            if let Some((id, _)) = self.inner.cache_statement.insert(sql, statement.clone()) {
198                self.inner.stream.write_msg(Close::Statement(id))?;
199                self.write_sync();
200
201                self.inner.stream.flush().await?;
202
203                self.wait_for_close_complete(1).await?;
204                self.recv_ready_for_query().await?;
205            }
206        }
207
208        Ok(statement)
209    }
210
211    pub(crate) async fn run<'e, 'c: 'e, 'q: 'e>(
212        &'c mut self,
213        query: SqlStr,
214        arguments: Option<PgArguments>,
215        persistent: bool,
216        metadata_opt: Option<Arc<PgStatementMetadata>>,
217    ) -> Result<impl Stream<Item = Result<Either<PgQueryResult, PgRow>, Error>> + 'e, Error> {
218        let mut logger = QueryLogger::new(query, self.inner.log_settings.clone());
219        let sql = logger.sql().as_str();
220
221        self.wait_until_ready().await?;
223
224        let mut metadata: Arc<PgStatementMetadata>;
225
226        let format = if let Some(mut arguments) = arguments {
227            let num_params = u16::try_from(arguments.len()).map_err(|_| {
234                err_protocol!(
235                    "PgConnection::run(): too many arguments for query: {}",
236                    arguments.len()
237                )
238            })?;
239
240            let (statement, metadata_) = self
243                .get_or_prepare(sql, &arguments.types, persistent, metadata_opt, false)
244                .await?;
245
246            metadata = metadata_;
247
248            arguments.apply_patches(self, &metadata.parameters).await?;
250
251            self.wait_until_ready().await?;
253
254            self.inner.stream.write_msg(Bind {
256                portal: PortalId::UNNAMED,
257                statement,
258                formats: &[PgValueFormat::Binary],
259                num_params,
260                params: &arguments.buffer,
261                result_formats: &[PgValueFormat::Binary],
262            })?;
263
264            self.inner.stream.write_msg(message::Execute {
267                portal: PortalId::UNNAMED,
268                limit: 0,
271            })?;
272            self.inner
282                .stream
283                .write_msg(Close::Portal(PortalId::UNNAMED))?;
284
285            self.write_sync();
291
292            PgValueFormat::Binary
294        } else {
295            self.inner.stream.write_msg(Query(sql))?;
297            self.inner.pending_ready_for_query_count += 1;
298
299            metadata = Arc::new(PgStatementMetadata::default());
301
302            PgValueFormat::Text
304        };
305
306        self.inner.stream.flush().await?;
307
308        Ok(try_stream! {
309            loop {
310                let message = self.inner.stream.recv().await?;
311
312                match message.format {
313                    BackendMessageFormat::BindComplete
314                    | BackendMessageFormat::ParseComplete
315                    | BackendMessageFormat::ParameterDescription
316                    | BackendMessageFormat::NoData
317                    | BackendMessageFormat::CloseComplete
319                    => {
320                        }
322
323                    BackendMessageFormat::CommandComplete => {
328                        let cc: CommandComplete = message.decode()?;
330
331                        let rows_affected = cc.rows_affected();
332                        logger.increase_rows_affected(rows_affected);
333                        r#yield!(Either::Left(PgQueryResult {
334                            rows_affected,
335                        }));
336                    }
337
338                    BackendMessageFormat::EmptyQueryResponse => {
339                        }
341
342                    BackendMessageFormat::PortalSuspended => {}
346
347                    BackendMessageFormat::RowDescription => {
348                        let (columns, column_names) = self
350                            .handle_row_description(Some(message.decode()?), false, false)
351                            .await?;
352
353                        metadata = Arc::new(PgStatementMetadata {
354                            column_names: Arc::new(column_names),
355                            columns,
356                            parameters: Vec::default(),
357                        });
358                    }
359
360                    BackendMessageFormat::DataRow => {
361                        logger.increment_rows_returned();
362
363                        let data: DataRow = message.decode()?;
365                        let row = PgRow {
366                            data,
367                            format,
368                            metadata: Arc::clone(&metadata),
369                        };
370
371                        r#yield!(Either::Right(row));
372                    }
373
374                    BackendMessageFormat::ReadyForQuery => {
375                        self.handle_ready_for_query(message)?;
377                        break;
378                    }
379
380                    _ => {
381                        return Err(err_protocol!(
382                            "execute: unexpected message: {:?}",
383                            message.format
384                        ));
385                    }
386                }
387            }
388
389            Ok(())
390        })
391    }
392}
393
394impl<'c> Executor<'c> for &'c mut PgConnection {
395    type Database = Postgres;
396
397    fn fetch_many<'e, 'q, E>(
398        self,
399        mut query: E,
400    ) -> BoxStream<'e, Result<Either<PgQueryResult, PgRow>, Error>>
401    where
402        'c: 'e,
403        E: Execute<'q, Self::Database>,
404        'q: 'e,
405        E: 'q,
406    {
407        #[allow(clippy::map_clone)]
409        let metadata = query.statement().map(|s| Arc::clone(&s.metadata));
410        let arguments = query.take_arguments().map_err(Error::Encode);
411        let persistent = query.persistent();
412        let sql = query.sql();
413
414        Box::pin(try_stream! {
415            let arguments = arguments?;
416            let mut s = pin!(self.run(sql, arguments, persistent, metadata).await?);
417
418            while let Some(v) = s.try_next().await? {
419                r#yield!(v);
420            }
421
422            Ok(())
423        })
424    }
425
426    fn fetch_optional<'e, 'q, E>(self, mut query: E) -> BoxFuture<'e, Result<Option<PgRow>, Error>>
427    where
428        'c: 'e,
429        E: Execute<'q, Self::Database>,
430        'q: 'e,
431        E: 'q,
432    {
433        #[allow(clippy::map_clone)]
435        let metadata = query.statement().map(|s| Arc::clone(&s.metadata));
436        let arguments = query.take_arguments().map_err(Error::Encode);
437        let persistent = query.persistent();
438
439        Box::pin(async move {
440            let sql = query.sql();
441            let arguments = arguments?;
442            let mut s = pin!(self.run(sql, arguments, persistent, metadata).await?);
443
444            let mut ret = None;
449            while let Some(result) = s.try_next().await? {
450                match result {
451                    Either::Right(r) if ret.is_none() => ret = Some(r),
452                    _ => {}
453                }
454            }
455            Ok(ret)
456        })
457    }
458
459    fn prepare_with<'e>(
460        self,
461        sql: SqlStr,
462        parameters: &'e [PgTypeInfo],
463    ) -> BoxFuture<'e, Result<PgStatement, Error>>
464    where
465        'c: 'e,
466    {
467        Box::pin(async move {
468            self.wait_until_ready().await?;
469
470            let (_, metadata) = self
471                .get_or_prepare(sql.as_str(), parameters, true, None, true)
472                .await?;
473
474            Ok(PgStatement { sql, metadata })
475        })
476    }
477
478    fn describe<'e>(self, sql: SqlStr) -> BoxFuture<'e, Result<Describe<Self::Database>, Error>>
479    where
480        'c: 'e,
481    {
482        Box::pin(async move {
483            self.wait_until_ready().await?;
484
485            let (stmt_id, metadata) = self
486                .get_or_prepare(sql.as_str(), &[], true, None, true)
487                .await?;
488
489            let nullable = self.get_nullable_for_columns(stmt_id, &metadata).await?;
490
491            Ok(Describe {
492                columns: metadata.columns.clone(),
493                nullable,
494                parameters: Some(Either::Left(metadata.parameters.clone())),
495            })
496        })
497    }
498}