sqlx_core_oldapi/postgres/connection/
executor.rs1use crate::describe::Describe;
2use crate::error::Error;
3use crate::executor::{Execute, Executor};
4use crate::logger::QueryLogger;
5use crate::postgres::message::{
6 self, Bind, Close, CommandComplete, DataRow, MessageFormat, ParameterDescription, Parse, Query,
7 RowDescription,
8};
9use crate::postgres::statement::PgStatementMetadata;
10use crate::postgres::type_info::PgType;
11use crate::postgres::types::Oid;
12use crate::postgres::{
13 statement::PgStatement, PgArguments, PgConnection, PgQueryResult, PgRow, PgTypeInfo,
14 PgValueFormat, Postgres,
15};
16use either::Either;
17use futures_core::future::BoxFuture;
18use futures_core::stream::BoxStream;
19use futures_core::Stream;
20use futures_util::{pin_mut, TryStreamExt};
21use std::{borrow::Cow, sync::Arc};
22
23async fn prepare(
24 conn: &mut PgConnection,
25 sql: &str,
26 parameters: &[PgTypeInfo],
27 metadata: Option<Arc<PgStatementMetadata>>,
28) -> Result<(Oid, Arc<PgStatementMetadata>), Error> {
29 let id = conn.next_statement_id;
30 conn.next_statement_id.incr_one();
31
32 let mut param_types = Vec::with_capacity(parameters.len());
37
38 for ty in parameters {
39 param_types.push(if let PgType::DeclareWithName(name) = &ty.0 {
40 conn.fetch_type_id_by_name(name).await?
41 } else {
42 ty.0.oid()
43 });
44 }
45
46 conn.wait_until_ready().await?;
48
49 conn.stream.write(Parse {
51 param_types: &*param_types,
52 query: sql,
53 statement: id,
54 });
55
56 if metadata.is_none() {
57 conn.stream.write(message::Describe::Statement(id));
59 }
60
61 conn.write_sync();
63 conn.stream.flush().await?;
64
65 let _: () = conn
67 .stream
68 .recv_expect(MessageFormat::ParseComplete)
69 .await?;
70
71 let metadata = if let Some(metadata) = metadata {
72 conn.recv_ready_for_query().await?;
74
75 metadata
77 } else {
78 let parameters = recv_desc_params(conn).await?;
79
80 let rows = recv_desc_rows(conn).await?;
81
82 conn.recv_ready_for_query().await?;
84
85 let parameters = conn.handle_parameter_description(parameters).await?;
86
87 let (columns, column_names) = conn.handle_row_description(rows, true).await?;
88
89 conn.wait_until_ready().await?;
92
93 Arc::new(PgStatementMetadata {
94 parameters,
95 columns,
96 column_names,
97 })
98 };
99
100 Ok((id, metadata))
101}
102
103async fn recv_desc_params(conn: &mut PgConnection) -> Result<ParameterDescription, Error> {
104 conn.stream
105 .recv_expect(MessageFormat::ParameterDescription)
106 .await
107}
108
109async fn recv_desc_rows(conn: &mut PgConnection) -> Result<Option<RowDescription>, Error> {
110 let rows: Option<RowDescription> = match conn.stream.recv().await? {
111 message if message.format == MessageFormat::RowDescription => Some(message.decode()?),
113
114 message if message.format == MessageFormat::NoData => None,
116
117 message => {
118 return Err(err_protocol!(
119 "expecting RowDescription or NoData but received {:?}",
120 message.format
121 ));
122 }
123 };
124
125 Ok(rows)
126}
127
128impl PgConnection {
129 pub(super) async fn wait_for_close_complete(&mut self, mut count: usize) -> Result<(), Error> {
131 while count > 0 {
133 match self.stream.recv().await? {
134 message if message.format == MessageFormat::PortalSuspended => {
135 }
138
139 message if message.format == MessageFormat::CloseComplete => {
140 count -= 1;
142 }
143
144 message => {
145 return Err(err_protocol!(
146 "expecting PortalSuspended or CloseComplete but received {:?}",
147 message.format
148 ));
149 }
150 }
151 }
152
153 Ok(())
154 }
155
156 pub(crate) fn write_sync(&mut self) {
157 self.stream.write(message::Sync);
158
159 self.pending_ready_for_query_count += 1;
161 }
162
163 async fn get_or_prepare<'a>(
164 &mut self,
165 sql: &str,
166 parameters: &[PgTypeInfo],
167 store_to_cache: bool,
169 metadata: Option<Arc<PgStatementMetadata>>,
172 ) -> Result<(Oid, Arc<PgStatementMetadata>), Error> {
173 if let Some(statement) = self.cache_statement.get_mut(sql) {
174 return Ok((*statement).clone());
175 }
176
177 let statement = prepare(self, sql, parameters, metadata).await?;
178
179 if store_to_cache && self.cache_statement.is_enabled() {
180 if let Some((id, _)) = self.cache_statement.insert(sql, statement.clone()) {
181 self.stream.write(Close::Statement(id));
182 self.write_sync();
183
184 self.stream.flush().await?;
185
186 self.wait_for_close_complete(1).await?;
187 self.recv_ready_for_query().await?;
188 }
189 }
190
191 Ok(statement)
192 }
193
194 async fn run<'e, 'c: 'e, 'q: 'e>(
195 &'c mut self,
196 query: &'q str,
197 arguments: Option<PgArguments>,
198 limit: u8,
199 persistent: bool,
200 metadata_opt: Option<Arc<PgStatementMetadata>>,
201 ) -> Result<impl Stream<Item = Result<Either<PgQueryResult, PgRow>, Error>> + 'e, Error> {
202 let mut logger = QueryLogger::new(query, self.log_settings.clone());
203
204 self.wait_until_ready().await?;
206
207 let mut metadata: Arc<PgStatementMetadata>;
208
209 let format = if let Some(mut arguments) = arguments {
210 let (statement, metadata_) = self
213 .get_or_prepare(query, &arguments.types, persistent, metadata_opt)
214 .await?;
215
216 metadata = metadata_;
217
218 arguments.apply_patches(self, &metadata.parameters).await?;
220
221 self.wait_until_ready().await?;
223
224 self.stream.write(Bind {
226 portal: None,
227 statement,
228 formats: &[PgValueFormat::Binary],
229 num_params: arguments.types.len() as i16,
230 params: &*arguments.buffer,
231 result_formats: &[PgValueFormat::Binary],
232 });
233
234 self.stream.write(message::Execute {
237 portal: None,
238 limit: limit.into(),
239 });
240 self.stream.write(message::Close::Portal(None));
250
251 self.write_sync();
257
258 PgValueFormat::Binary
260 } else {
261 self.stream.write(Query(query));
263 self.pending_ready_for_query_count += 1;
264
265 metadata = Arc::new(PgStatementMetadata::default());
267
268 PgValueFormat::Text
270 };
271
272 self.stream.flush().await?;
273
274 Ok(try_stream! {
275 loop {
276 let message = self.stream.recv().await?;
277
278 match message.format {
279 MessageFormat::BindComplete
280 | MessageFormat::ParseComplete
281 | MessageFormat::ParameterDescription
282 | MessageFormat::NoData
283 | MessageFormat::CloseComplete
285 => {
286 }
288
289 MessageFormat::CommandComplete => {
294 let cc: CommandComplete = message.decode()?;
296
297 let rows_affected = cc.rows_affected();
298 logger.increase_rows_affected(rows_affected);
299 r#yield!(Either::Left(PgQueryResult {
300 rows_affected,
301 }));
302 }
303
304 MessageFormat::EmptyQueryResponse => {
305 }
307
308 MessageFormat::PortalSuspended => {}
312
313 MessageFormat::RowDescription => {
314 let (columns, column_names) = self
316 .handle_row_description(Some(message.decode()?), false)
317 .await?;
318
319 metadata = Arc::new(PgStatementMetadata {
320 column_names,
321 columns,
322 parameters: Vec::default(),
323 });
324 }
325
326 MessageFormat::DataRow => {
327 logger.increment_rows_returned();
328
329 let data: DataRow = message.decode()?;
331 let row = PgRow {
332 data,
333 format,
334 metadata: Arc::clone(&metadata),
335 };
336
337 r#yield!(Either::Right(row));
338 }
339
340 MessageFormat::ReadyForQuery => {
341 self.handle_ready_for_query(message)?;
343 break;
344 }
345
346 _ => {
347 return Err(err_protocol!(
348 "execute: unexpected message: {:?}",
349 message.format
350 ));
351 }
352 }
353 }
354
355 Ok(())
356 })
357 }
358}
359
360impl<'c> Executor<'c> for &'c mut PgConnection {
361 type Database = Postgres;
362
363 fn fetch_many<'e, 'q: 'e, E: 'q>(
364 self,
365 mut query: E,
366 ) -> BoxStream<'e, Result<Either<PgQueryResult, PgRow>, Error>>
367 where
368 'c: 'e,
369 E: Execute<'q, Self::Database>,
370 {
371 let sql = query.sql();
372 let metadata = query.statement().map(|s| Arc::clone(&s.metadata));
373 let arguments = query.take_arguments();
374 let persistent = query.persistent();
375
376 Box::pin(try_stream! {
377 let s = self.run(sql, arguments, 0, persistent, metadata).await?;
378 pin_mut!(s);
379
380 while let Some(v) = s.try_next().await? {
381 r#yield!(v);
382 }
383
384 Ok(())
385 })
386 }
387
388 fn fetch_optional<'e, 'q: 'e, E: 'q>(
389 self,
390 mut query: E,
391 ) -> BoxFuture<'e, Result<Option<PgRow>, Error>>
392 where
393 'c: 'e,
394 E: Execute<'q, Self::Database>,
395 {
396 let sql = query.sql();
397 let metadata = query.statement().map(|s| Arc::clone(&s.metadata));
398 let arguments = query.take_arguments();
399 let persistent = query.persistent();
400
401 Box::pin(async move {
402 let s = self.run(sql, arguments, 1, persistent, metadata).await?;
403 pin_mut!(s);
404
405 while let Some(s) = s.try_next().await? {
406 if let Either::Right(r) = s {
407 return Ok(Some(r));
408 }
409 }
410
411 Ok(None)
412 })
413 }
414
415 fn prepare_with<'e, 'q: 'e>(
416 self,
417 sql: &'q str,
418 parameters: &'e [PgTypeInfo],
419 ) -> BoxFuture<'e, Result<PgStatement<'q>, Error>>
420 where
421 'c: 'e,
422 {
423 Box::pin(async move {
424 self.wait_until_ready().await?;
425
426 let (_, metadata) = self.get_or_prepare(sql, parameters, true, None).await?;
427
428 Ok(PgStatement {
429 sql: Cow::Borrowed(sql),
430 metadata,
431 })
432 })
433 }
434
435 fn describe<'e, 'q: 'e>(
436 self,
437 sql: &'q str,
438 ) -> BoxFuture<'e, Result<Describe<Self::Database>, Error>>
439 where
440 'c: 'e,
441 {
442 Box::pin(async move {
443 self.wait_until_ready().await?;
444
445 let (stmt_id, metadata) = self.get_or_prepare(sql, &[], true, None).await?;
446
447 let nullable = self.get_nullable_for_columns(stmt_id, &metadata).await?;
448
449 Ok(Describe {
450 columns: metadata.columns.clone(),
451 nullable,
452 parameters: Some(Either::Left(metadata.parameters.clone())),
453 })
454 })
455 }
456}