sqlx_core_guts/postgres/connection/
executor.rs1use crate::describe::Describe;
2use crate::error::Error;
3use crate::executor::{Execute, Executor};
4use crate::logger::QueryLogger;
5use crate::postgres::message::{
6 self, Bind, Close, CommandComplete, DataRow, MessageFormat, ParameterDescription, Parse, Query,
7 RowDescription,
8};
9use crate::postgres::statement::PgStatementMetadata;
10use crate::postgres::type_info::PgType;
11use crate::postgres::types::Oid;
12use crate::postgres::{
13 statement::PgStatement, PgArguments, PgConnection, PgQueryResult, PgRow, PgTypeInfo,
14 PgValueFormat, Postgres,
15};
16use either::Either;
17use futures_core::future::BoxFuture;
18use futures_core::stream::BoxStream;
19use futures_core::Stream;
20use futures_util::{pin_mut, TryStreamExt};
21use std::{borrow::Cow, sync::Arc};
22
23async fn prepare(
24 conn: &mut PgConnection,
25 sql: &str,
26 parameters: &[PgTypeInfo],
27 metadata: Option<Arc<PgStatementMetadata>>,
28) -> Result<(Oid, Arc<PgStatementMetadata>), Error> {
29 let id = conn.next_statement_id;
30 conn.next_statement_id.incr_one();
31
32 let mut param_types = Vec::with_capacity(parameters.len());
37
38 for ty in parameters {
39 param_types.push(if let PgType::DeclareWithName(name) = &ty.0 {
40 conn.fetch_type_id_by_name(name).await?
41 } else {
42 ty.0.oid()
43 });
44 }
45
46 conn.wait_until_ready().await?;
48
49 conn.stream.write(Parse {
51 param_types: &*param_types,
52 query: sql,
53 statement: id,
54 });
55
56 if metadata.is_none() {
57 conn.stream.write(message::Describe::Statement(id));
59 }
60
61 conn.write_sync();
63 conn.stream.flush().await?;
64
65 let _ = conn
67 .stream
68 .recv_expect(MessageFormat::ParseComplete)
69 .await?;
70
71 let metadata = if let Some(metadata) = metadata {
72 conn.recv_ready_for_query().await?;
74
75 metadata
77 } else {
78 let parameters = recv_desc_params(conn).await?;
79
80 let rows = recv_desc_rows(conn).await?;
81
82 conn.recv_ready_for_query().await?;
84
85 let parameters = conn.handle_parameter_description(parameters).await?;
86
87 let (columns, column_names) = conn.handle_row_description(rows, true).await?;
88
89 conn.wait_until_ready().await?;
92
93 Arc::new(PgStatementMetadata {
94 parameters,
95 columns,
96 column_names,
97 })
98 };
99
100 Ok((id, metadata))
101}
102
103async fn recv_desc_params(conn: &mut PgConnection) -> Result<ParameterDescription, Error> {
104 conn.stream
105 .recv_expect(MessageFormat::ParameterDescription)
106 .await
107}
108
109async fn recv_desc_rows(conn: &mut PgConnection) -> Result<Option<RowDescription>, Error> {
110 let rows: Option<RowDescription> = match conn.stream.recv().await? {
111 message if message.format == MessageFormat::RowDescription => Some(message.decode()?),
113
114 message if message.format == MessageFormat::NoData => None,
116
117 message => {
118 return Err(err_protocol!(
119 "expecting RowDescription or NoData but received {:?}",
120 message.format
121 ));
122 }
123 };
124
125 Ok(rows)
126}
127
128impl PgConnection {
129 pub(super) async fn wait_for_close_complete(&mut self, mut count: usize) -> Result<(), Error> {
131 while count > 0 {
133 match self.stream.recv().await? {
134 message if message.format == MessageFormat::PortalSuspended => {
135 }
139
140 message if message.format == MessageFormat::CloseComplete => {
141 count -= 1;
143 }
144
145 message => {
146 return Err(err_protocol!(
147 "expecting PortalSuspended or CloseComplete but received {:?}",
148 message.format
149 ));
150 }
151 }
152 }
153
154 Ok(())
155 }
156
157 pub(crate) fn write_sync(&mut self) {
158 self.stream.write(message::Sync);
159
160 self.pending_ready_for_query_count += 1;
162 }
163
164 async fn get_or_prepare<'a>(
165 &mut self,
166 sql: &str,
167 parameters: &[PgTypeInfo],
168 store_to_cache: bool,
170 metadata: Option<Arc<PgStatementMetadata>>,
173 ) -> Result<(Oid, Arc<PgStatementMetadata>), Error> {
174 if let Some(statement) = self.cache_statement.get_mut(sql) {
175 return Ok((*statement).clone());
176 }
177
178 let statement = prepare(self, sql, parameters, metadata).await?;
179
180 if store_to_cache && self.cache_statement.is_enabled() {
181 if let Some((id, _)) = self.cache_statement.insert(sql, statement.clone()) {
182 self.stream.write(Close::Statement(id));
183 self.write_sync();
184
185 self.stream.flush().await?;
186
187 self.wait_for_close_complete(1).await?;
188 self.recv_ready_for_query().await?;
189 }
190 }
191
192 Ok(statement)
193 }
194
195 async fn run<'e, 'c: 'e, 'q: 'e>(
196 &'c mut self,
197 query: &'q str,
198 arguments: Option<PgArguments>,
199 limit: u8,
200 persistent: bool,
201 metadata_opt: Option<Arc<PgStatementMetadata>>,
202 ) -> Result<impl Stream<Item = Result<Either<PgQueryResult, PgRow>, Error>> + 'e, Error> {
203 let mut logger = QueryLogger::new(query, self.log_settings.clone());
204
205 self.wait_until_ready().await?;
207
208 let mut metadata: Arc<PgStatementMetadata>;
209
210 let format = if let Some(mut arguments) = arguments {
211 let (statement, metadata_) = self
214 .get_or_prepare(query, &arguments.types, persistent, metadata_opt)
215 .await?;
216
217 metadata = metadata_;
218
219 arguments.apply_patches(self, &metadata.parameters).await?;
221
222 self.wait_until_ready().await?;
225
226 self.stream.write(Bind {
228 portal: None,
229 statement,
230 formats: &[PgValueFormat::Binary],
231 num_params: arguments.types.len() as i16,
232 params: &*arguments.buffer,
233 result_formats: &[PgValueFormat::Binary],
234 });
235
236 self.stream.write(message::Execute {
239 portal: None,
240 limit: limit.into(),
241 });
242
243 self.write_sync();
249
250 PgValueFormat::Binary
252 } else {
253 self.stream.write(Query(query));
255 self.pending_ready_for_query_count += 1;
256
257 metadata = Arc::new(PgStatementMetadata::default());
259
260 PgValueFormat::Text
262 };
263
264 self.stream.flush().await?;
265
266 Ok(try_stream! {
267 loop {
268 let message = self.stream.recv().await?;
269
270 match message.format {
271 MessageFormat::BindComplete
272 | MessageFormat::ParseComplete
273 | MessageFormat::ParameterDescription
274 | MessageFormat::NoData => {
275 }
277
278 MessageFormat::CommandComplete => {
279 let cc: CommandComplete = message.decode()?;
281
282 let rows_affected = cc.rows_affected();
283 logger.increase_rows_affected(rows_affected);
284 r#yield!(Either::Left(PgQueryResult {
285 rows_affected,
286 }));
287 }
288
289 MessageFormat::EmptyQueryResponse => {
290 }
292
293 MessageFormat::RowDescription => {
294 let (columns, column_names) = self
296 .handle_row_description(Some(message.decode()?), false)
297 .await?;
298
299 metadata = Arc::new(PgStatementMetadata {
300 column_names,
301 columns,
302 parameters: Vec::default(),
303 });
304 }
305
306 MessageFormat::DataRow => {
307 logger.increment_rows_returned();
308
309 let data: DataRow = message.decode()?;
311 let row = PgRow {
312 data,
313 format,
314 metadata: Arc::clone(&metadata),
315 };
316
317 r#yield!(Either::Right(row));
318 }
319
320 MessageFormat::ReadyForQuery => {
321 self.handle_ready_for_query(message)?;
323 break;
324 }
325
326 _ => {
327 return Err(err_protocol!(
328 "execute: unexpected message: {:?}",
329 message.format
330 ));
331 }
332 }
333 }
334
335 Ok(())
336 })
337 }
338}
339
340impl<'c> Executor<'c> for &'c mut PgConnection {
341 type Database = Postgres;
342
343 fn fetch_many<'e, 'q: 'e, E: 'q>(
344 self,
345 mut query: E,
346 ) -> BoxStream<'e, Result<Either<PgQueryResult, PgRow>, Error>>
347 where
348 'c: 'e,
349 E: Execute<'q, Self::Database>,
350 {
351 let sql = query.sql();
352 let metadata = query.statement().map(|s| Arc::clone(&s.metadata));
353 let arguments = query.take_arguments();
354 let persistent = query.persistent();
355
356 Box::pin(try_stream! {
357 let s = self.run(sql, arguments, 0, persistent, metadata).await?;
358 pin_mut!(s);
359
360 while let Some(v) = s.try_next().await? {
361 r#yield!(v);
362 }
363
364 Ok(())
365 })
366 }
367
368 fn fetch_optional<'e, 'q: 'e, E: 'q>(
369 self,
370 mut query: E,
371 ) -> BoxFuture<'e, Result<Option<PgRow>, Error>>
372 where
373 'c: 'e,
374 E: Execute<'q, Self::Database>,
375 {
376 let sql = query.sql();
377 let metadata = query.statement().map(|s| Arc::clone(&s.metadata));
378 let arguments = query.take_arguments();
379 let persistent = query.persistent();
380
381 Box::pin(async move {
382 let s = self.run(sql, arguments, 1, persistent, metadata).await?;
383 pin_mut!(s);
384
385 while let Some(s) = s.try_next().await? {
386 if let Either::Right(r) = s {
387 return Ok(Some(r));
388 }
389 }
390
391 Ok(None)
392 })
393 }
394
395 fn prepare_with<'e, 'q: 'e>(
396 self,
397 sql: &'q str,
398 parameters: &'e [PgTypeInfo],
399 ) -> BoxFuture<'e, Result<PgStatement<'q>, Error>>
400 where
401 'c: 'e,
402 {
403 Box::pin(async move {
404 self.wait_until_ready().await?;
405
406 let (_, metadata) = self.get_or_prepare(sql, parameters, true, None).await?;
407
408 Ok(PgStatement {
409 sql: Cow::Borrowed(sql),
410 metadata,
411 })
412 })
413 }
414
415 fn describe<'e, 'q: 'e>(
416 self,
417 sql: &'q str,
418 ) -> BoxFuture<'e, Result<Describe<Self::Database>, Error>>
419 where
420 'c: 'e,
421 {
422 Box::pin(async move {
423 self.wait_until_ready().await?;
424
425 let (stmt_id, metadata) = self.get_or_prepare(sql, &[], true, None).await?;
426
427 let nullable = self.get_nullable_for_columns(stmt_id, &metadata).await?;
428
429 Ok(Describe {
430 columns: metadata.columns.clone(),
431 nullable,
432 parameters: Some(Either::Left(metadata.parameters.clone())),
433 })
434 })
435 }
436}