zero_postgres/tokio/pipeline/
mod.rs1use std::collections::VecDeque;
33
34use crate::pipeline::Expectation;
35use crate::pipeline::Ticket;
36
37use crate::conversion::{FromRow, ToParams};
38use crate::error::{Error, Result};
39use crate::handler::ExtendedHandler;
40use crate::protocol::backend::{
41 BindComplete, CommandComplete, DataRow, EmptyQueryResponse, ErrorResponse, NoData,
42 ParseComplete, RawMessage, ReadyForQuery, RowDescription, msg_type,
43};
44use crate::protocol::frontend::{
45 write_bind, write_describe_portal, write_execute, write_flush, write_parse, write_sync,
46};
47use crate::state::extended::PreparedStatement;
48use crate::statement::{IntoStatement, StatementRef};
49
50use super::conn::Conn;
51
52pub struct Pipeline<'a> {
56 conn: &'a mut Conn,
57 queue_seq: usize,
59 claim_seq: usize,
61 aborted: bool,
63 column_buffer: Vec<u8>,
65 expectations: VecDeque<Expectation>,
67}
68
69impl<'a> Pipeline<'a> {
70 #[cfg(feature = "lowlevel")]
75 pub fn new(conn: &'a mut Conn) -> Self {
76 Self::new_inner(conn)
77 }
78
79 pub(crate) fn new_inner(conn: &'a mut Conn) -> Self {
81 conn.buffer_set.write_buffer.clear();
82 Self {
83 conn,
84 queue_seq: 0,
85 claim_seq: 0,
86 aborted: false,
87 column_buffer: Vec::new(),
88 expectations: VecDeque::new(),
89 }
90 }
91
92 #[cfg(feature = "lowlevel")]
97 pub async fn cleanup(&mut self) {
98 self.cleanup_inner().await;
99 }
100
101 #[cfg(not(feature = "lowlevel"))]
102 pub(crate) async fn cleanup(&mut self) {
103 self.cleanup_inner().await;
104 }
105
106 async fn cleanup_inner(&mut self) {
107 if self.queue_seq == 0 && self.expectations.is_empty() {
109 return;
110 }
111
112 if !self.conn.buffer_set.write_buffer.is_empty()
114 || !self.expectations.iter().any(|e| *e == Expectation::Sync)
115 {
116 let _ = self.sync().await;
117 }
118
119 if self.aborted {
121 while let Some(expectation) = self.expectations.pop_front() {
123 if expectation == Expectation::Sync {
124 let _ = self.consume_ready_for_query().await;
125 }
126 }
127 } else {
128 while let Some(expectation) = self.expectations.pop_front() {
130 let _ = self.drain_expectation(expectation).await;
131 }
132 }
133
134 self.queue_seq = 0;
136 self.claim_seq = 0;
137 self.aborted = false;
138 }
139
140 async fn drain_expectation(&mut self, expectation: Expectation) {
142 let mut handler = crate::handler::DropHandler::new();
143 let _ = match expectation {
144 Expectation::ParseBindExecute => self.claim_parse_bind_exec_inner(&mut handler).await,
145 Expectation::BindExecute => self.claim_bind_exec_inner(&mut handler, None).await,
146 Expectation::Sync => self.consume_ready_for_query().await,
147 };
148 }
149
150 pub fn exec<'s, P: ToParams>(
179 &mut self,
180 statement: &'s (impl IntoStatement + ?Sized),
181 params: P,
182 ) -> Result<Ticket<'s>> {
183 let seq = self.queue_seq;
184 self.queue_seq += 1;
185
186 match statement.statement_ref() {
187 StatementRef::Sql(sql) => {
188 self.exec_sql_inner(sql, ¶ms)?;
189 Ok(Ticket { seq, stmt: None })
190 }
191 StatementRef::Prepared(stmt) => {
192 self.exec_prepared_inner(&stmt.wire_name(), &stmt.param_oids, ¶ms)?;
193 Ok(Ticket {
194 seq,
195 stmt: Some(stmt),
196 })
197 }
198 }
199 }
200
201 fn exec_sql_inner<P: ToParams>(&mut self, sql: &str, params: &P) -> Result<()> {
202 let param_oids = params.natural_oids();
203 let buf = &mut self.conn.buffer_set.write_buffer;
204 write_parse(buf, "", sql, ¶m_oids);
205 write_bind(buf, "", "", params, ¶m_oids)?;
206 write_describe_portal(buf, "");
207 write_execute(buf, "", 0);
208 self.expectations.push_back(Expectation::ParseBindExecute);
209 Ok(())
210 }
211
212 fn exec_prepared_inner<P: ToParams>(
213 &mut self,
214 stmt_name: &str,
215 param_oids: &[u32],
216 params: &P,
217 ) -> Result<()> {
218 let buf = &mut self.conn.buffer_set.write_buffer;
219 write_bind(buf, "", stmt_name, params, param_oids)?;
220 write_execute(buf, "", 0);
222 self.expectations.push_back(Expectation::BindExecute);
223 Ok(())
224 }
225
226 pub async fn flush(&mut self) -> Result<()> {
231 if !self.conn.buffer_set.write_buffer.is_empty() {
232 write_flush(&mut self.conn.buffer_set.write_buffer);
233 self.conn
234 .stream
235 .write_all(&self.conn.buffer_set.write_buffer)
236 .await?;
237 self.conn.stream.flush().await?;
238 self.conn.buffer_set.write_buffer.clear();
239 }
240 Ok(())
241 }
242
243 pub async fn sync(&mut self) -> Result<()> {
248 let result = self.sync_inner().await;
249 if let Err(e) = &result
250 && e.is_connection_broken()
251 {
252 self.conn.is_broken = true;
253 }
254 result
255 }
256
257 async fn sync_inner(&mut self) -> Result<()> {
258 write_sync(&mut self.conn.buffer_set.write_buffer);
259 self.expectations.push_back(Expectation::Sync);
260 self.conn
261 .stream
262 .write_all(&self.conn.buffer_set.write_buffer)
263 .await?;
264 self.conn.stream.flush().await?;
265 self.conn.buffer_set.write_buffer.clear();
266 Ok(())
267 }
268
269 async fn consume_ready_for_query(&mut self) -> Result<()> {
271 loop {
272 self.conn
273 .stream
274 .read_message(&mut self.conn.buffer_set)
275 .await?;
276 let type_byte = self.conn.buffer_set.type_byte;
277
278 if RawMessage::is_async_type(type_byte) {
279 continue;
280 }
281
282 if type_byte == msg_type::ERROR_RESPONSE {
283 let error = ErrorResponse::parse(&self.conn.buffer_set.read_buffer)?;
284 return Err(error.into_error());
285 }
286
287 if type_byte == msg_type::READY_FOR_QUERY {
288 let ready = ReadyForQuery::parse(&self.conn.buffer_set.read_buffer)?;
289 self.conn.transaction_status = ready.transaction_status().unwrap_or_default();
290 return Ok(());
291 }
292 }
293 }
294
295 async fn consume_pending_syncs(&mut self) -> Result<()> {
297 while self.expectations.front() == Some(&Expectation::Sync) {
298 self.expectations.pop_front();
299 self.consume_ready_for_query().await?;
300 self.aborted = false;
302 }
303 Ok(())
304 }
305
306 #[cfg(feature = "lowlevel")]
314 pub async fn claim<H: ExtendedHandler>(
315 &mut self,
316 ticket: Ticket<'_>,
317 handler: &mut H,
318 ) -> Result<()> {
319 self.claim_with_handler(ticket, handler).await
320 }
321
322 async fn claim_with_handler<H: ExtendedHandler>(
323 &mut self,
324 ticket: Ticket<'_>,
325 handler: &mut H,
326 ) -> Result<()> {
327 self.check_sequence(ticket.seq)?;
328
329 if !self.conn.buffer_set.write_buffer.is_empty() {
331 self.sync().await?;
332 }
333
334 if self.aborted {
335 self.claim_seq += 1;
336 self.expectations.pop_front();
338 self.consume_pending_syncs().await?;
339 return Err(Error::LibraryBug(
340 "pipeline aborted due to earlier error".into(),
341 ));
342 }
343
344 let expectation = self.expectations.pop_front();
345
346 let result = match expectation {
347 Some(Expectation::ParseBindExecute) => self.claim_parse_bind_exec_inner(handler).await,
348 Some(Expectation::BindExecute) => {
349 self.claim_bind_exec_inner(handler, ticket.stmt).await
350 }
351 Some(Expectation::Sync) => Err(Error::LibraryBug("unexpected Sync expectation".into())),
352 None => Err(Error::LibraryBug("no expectation in queue".into())),
353 };
354
355 if let Err(e) = &result {
356 if e.is_connection_broken() {
357 self.conn.is_broken = true;
358 }
359 self.aborted = true;
360 }
361 self.claim_seq += 1;
362 self.consume_pending_syncs().await?;
363 result
364 }
365
366 pub async fn claim_collect<T: for<'b> FromRow<'b>>(
370 &mut self,
371 ticket: Ticket<'_>,
372 ) -> Result<Vec<T>> {
373 let mut handler = crate::handler::CollectHandler::<T>::new();
374 self.claim_with_handler(ticket, &mut handler).await?;
375 Ok(handler.into_rows())
376 }
377
378 pub async fn claim_one<T: for<'b> FromRow<'b>>(
382 &mut self,
383 ticket: Ticket<'_>,
384 ) -> Result<Option<T>> {
385 let mut handler = crate::handler::FirstRowHandler::<T>::new();
386 self.claim_with_handler(ticket, &mut handler).await?;
387 Ok(handler.into_row())
388 }
389
390 pub async fn claim_drop(&mut self, ticket: Ticket<'_>) -> Result<()> {
394 let mut handler = crate::handler::DropHandler::new();
395 self.claim_with_handler(ticket, &mut handler).await
396 }
397
398 fn check_sequence(&self, seq: usize) -> Result<()> {
400 if seq != self.claim_seq {
401 return Err(Error::InvalidUsage(format!(
402 "claim out of order: expected seq {}, got {}",
403 self.claim_seq, seq
404 )));
405 }
406 Ok(())
407 }
408
409 async fn claim_parse_bind_exec_inner<H: ExtendedHandler>(
411 &mut self,
412 handler: &mut H,
413 ) -> Result<()> {
414 self.read_next_message().await?;
416 if self.conn.buffer_set.type_byte != msg_type::PARSE_COMPLETE {
417 return self.unexpected_message("ParseComplete");
418 }
419 ParseComplete::parse(&self.conn.buffer_set.read_buffer)?;
420
421 self.read_next_message().await?;
423 if self.conn.buffer_set.type_byte != msg_type::BIND_COMPLETE {
424 return self.unexpected_message("BindComplete");
425 }
426 BindComplete::parse(&self.conn.buffer_set.read_buffer)?;
427
428 self.claim_rows_inner(handler).await
430 }
431
432 async fn claim_bind_exec_inner<H: ExtendedHandler>(
434 &mut self,
435 handler: &mut H,
436 stmt: Option<&PreparedStatement>,
437 ) -> Result<()> {
438 self.read_next_message().await?;
440 if self.conn.buffer_set.type_byte != msg_type::BIND_COMPLETE {
441 return self.unexpected_message("BindComplete");
442 }
443 BindComplete::parse(&self.conn.buffer_set.read_buffer)?;
444
445 let row_desc = stmt.and_then(|s| s.row_desc_payload());
447
448 self.claim_rows_cached_inner(handler, row_desc).await
450 }
451
452 async fn claim_rows_inner<H: ExtendedHandler>(&mut self, handler: &mut H) -> Result<()> {
454 self.read_next_message().await?;
456 let has_rows = match self.conn.buffer_set.type_byte {
457 msg_type::ROW_DESCRIPTION => {
458 self.column_buffer.clear();
459 self.column_buffer
460 .extend_from_slice(&self.conn.buffer_set.read_buffer);
461 true
462 }
463 msg_type::NO_DATA => {
464 NoData::parse(&self.conn.buffer_set.read_buffer)?;
465 false
467 }
468 _ => {
469 return Err(Error::LibraryBug(format!(
470 "expected RowDescription or NoData, got '{}'",
471 self.conn.buffer_set.type_byte as char
472 )));
473 }
474 };
475
476 loop {
478 self.read_next_message().await?;
479 let type_byte = self.conn.buffer_set.type_byte;
480
481 match type_byte {
482 msg_type::DATA_ROW => {
483 if !has_rows {
484 return Err(Error::LibraryBug(
485 "received DataRow but no RowDescription".into(),
486 ));
487 }
488 let cols = RowDescription::parse(&self.column_buffer)?;
489 let row = DataRow::parse(&self.conn.buffer_set.read_buffer)?;
490 handler.row(cols, row)?;
491 }
492 msg_type::COMMAND_COMPLETE => {
493 let cmd = CommandComplete::parse(&self.conn.buffer_set.read_buffer)?;
494 handler.result_end(cmd)?;
495 return Ok(());
496 }
497 msg_type::EMPTY_QUERY_RESPONSE => {
498 EmptyQueryResponse::parse(&self.conn.buffer_set.read_buffer)?;
499 return Ok(());
500 }
501 _ => {
502 return Err(Error::LibraryBug(format!(
503 "unexpected message type in pipeline claim: '{}'",
504 type_byte as char
505 )));
506 }
507 }
508 }
509 }
510
511 async fn claim_rows_cached_inner<H: ExtendedHandler>(
513 &mut self,
514 handler: &mut H,
515 row_desc: Option<&[u8]>,
516 ) -> Result<()> {
517 loop {
519 self.read_next_message().await?;
520 let type_byte = self.conn.buffer_set.type_byte;
521
522 match type_byte {
523 msg_type::DATA_ROW => {
524 let row_desc = row_desc.ok_or_else(|| {
525 Error::LibraryBug("received DataRow but no RowDescription cached".into())
526 })?;
527 let cols = RowDescription::parse(row_desc)?;
528 let row = DataRow::parse(&self.conn.buffer_set.read_buffer)?;
529 handler.row(cols, row)?;
530 }
531 msg_type::COMMAND_COMPLETE => {
532 let cmd = CommandComplete::parse(&self.conn.buffer_set.read_buffer)?;
533 handler.result_end(cmd)?;
534 return Ok(());
535 }
536 msg_type::EMPTY_QUERY_RESPONSE => {
537 EmptyQueryResponse::parse(&self.conn.buffer_set.read_buffer)?;
538 return Ok(());
539 }
540 _ => {
541 return Err(Error::LibraryBug(format!(
542 "unexpected message type in pipeline claim: '{}'",
543 type_byte as char
544 )));
545 }
546 }
547 }
548 }
549
550 async fn read_next_message(&mut self) -> Result<()> {
552 loop {
553 self.conn
554 .stream
555 .read_message(&mut self.conn.buffer_set)
556 .await?;
557 let type_byte = self.conn.buffer_set.type_byte;
558
559 if RawMessage::is_async_type(type_byte) {
561 continue;
562 }
563
564 if type_byte == msg_type::ERROR_RESPONSE {
566 let error = ErrorResponse::parse(&self.conn.buffer_set.read_buffer)?;
567 return Err(error.into_error());
568 }
569
570 return Ok(());
571 }
572 }
573
574 fn unexpected_message<T>(&self, expected: &str) -> Result<T> {
576 Err(Error::LibraryBug(format!(
577 "expected {}, got '{}'",
578 expected, self.conn.buffer_set.type_byte as char
579 )))
580 }
581
582 pub fn pending_count(&self) -> usize {
584 self.queue_seq - self.claim_seq
585 }
586
587 pub fn is_aborted(&self) -> bool {
589 self.aborted
590 }
591}