// zero_postgres/tokio/pipeline/mod.rs

use std::collections::VecDeque;

use crate::pipeline::Expectation;
use crate::pipeline::Ticket;

use crate::conversion::{FromRow, ToParams};
use crate::error::{Error, Result};
use crate::handler::BinaryHandler;
use crate::protocol::backend::{
    BindComplete, CommandComplete, DataRow, EmptyQueryResponse, ErrorResponse, NoData,
    ParseComplete, RawMessage, ReadyForQuery, RowDescription, msg_type,
};
use crate::protocol::frontend::{
    write_bind, write_describe_portal, write_execute, write_flush, write_parse, write_sync,
};
use crate::state::extended::PreparedStatement;
use crate::statement::IntoStatement;

use super::conn::Conn;

52pub struct Pipeline<'a> {
56 conn: &'a mut Conn,
57 queue_seq: usize,
59 claim_seq: usize,
61 aborted: bool,
63 column_buffer: Vec<u8>,
65 expectations: VecDeque<Expectation>,
67}
68
69impl<'a> Pipeline<'a> {
70 #[cfg(feature = "lowlevel")]
75 pub fn new(conn: &'a mut Conn) -> Self {
76 Self::new_inner(conn)
77 }
78
79 pub(crate) fn new_inner(conn: &'a mut Conn) -> Self {
81 conn.buffer_set.write_buffer.clear();
82 Self {
83 conn,
84 queue_seq: 0,
85 claim_seq: 0,
86 aborted: false,
87 column_buffer: Vec::new(),
88 expectations: VecDeque::new(),
89 }
90 }
91
92 #[cfg(feature = "lowlevel")]
97 pub async fn cleanup(&mut self) {
98 self.cleanup_inner().await;
99 }
100
101 #[cfg(not(feature = "lowlevel"))]
102 pub(crate) async fn cleanup(&mut self) {
103 self.cleanup_inner().await;
104 }
105
106 async fn cleanup_inner(&mut self) {
107 if self.queue_seq == 0 {
109 return;
110 }
111
112 if !self.conn.buffer_set.write_buffer.is_empty() {
114 let _ = self.sync().await;
115 } else if !self.expectations.iter().any(|e| *e == Expectation::Sync) {
116 let _ = self.sync().await;
118 }
119
120 if self.aborted {
122 while let Some(expectation) = self.expectations.pop_front() {
124 if expectation == Expectation::Sync {
125 let _ = self.consume_ready_for_query().await;
126 }
127 }
128 } else {
129 while let Some(expectation) = self.expectations.pop_front() {
131 let _ = self.drain_expectation(expectation).await;
132 }
133 }
134
135 self.queue_seq = 0;
137 self.claim_seq = 0;
138 self.aborted = false;
139 }
140
141 async fn drain_expectation(&mut self, expectation: Expectation) {
143 let mut handler = crate::handler::DropHandler::new();
144 let _ = match expectation {
145 Expectation::ParseBindExecute => self.claim_parse_bind_exec_inner(&mut handler).await,
146 Expectation::BindExecute => self.claim_bind_exec_inner(&mut handler, None).await,
147 Expectation::Sync => self.consume_ready_for_query().await,
148 };
149 }
150
151 pub fn exec<'s, P: ToParams>(
180 &mut self,
181 statement: &'s (impl IntoStatement + ?Sized),
182 params: P,
183 ) -> Result<Ticket<'s>> {
184 let seq = self.queue_seq;
185 self.queue_seq += 1;
186
187 if statement.needs_parse() {
188 self.exec_sql_inner(statement.as_sql().unwrap(), ¶ms)?;
189 Ok(Ticket { seq, stmt: None })
190 } else {
191 let stmt = statement.as_prepared().unwrap();
192 self.exec_prepared_inner(&stmt.wire_name(), &stmt.param_oids, ¶ms)?;
193 Ok(Ticket {
194 seq,
195 stmt: Some(stmt),
196 })
197 }
198 }
199
200 fn exec_sql_inner<P: ToParams>(&mut self, sql: &str, params: &P) -> Result<()> {
201 let param_oids = params.natural_oids();
202 let buf = &mut self.conn.buffer_set.write_buffer;
203 write_parse(buf, "", sql, ¶m_oids);
204 write_bind(buf, "", "", params, ¶m_oids)?;
205 write_describe_portal(buf, "");
206 write_execute(buf, "", 0);
207 self.expectations.push_back(Expectation::ParseBindExecute);
208 Ok(())
209 }
210
211 fn exec_prepared_inner<P: ToParams>(
212 &mut self,
213 stmt_name: &str,
214 param_oids: &[u32],
215 params: &P,
216 ) -> Result<()> {
217 let buf = &mut self.conn.buffer_set.write_buffer;
218 write_bind(buf, "", stmt_name, params, param_oids)?;
219 write_execute(buf, "", 0);
221 self.expectations.push_back(Expectation::BindExecute);
222 Ok(())
223 }
224
225 pub async fn flush(&mut self) -> Result<()> {
230 if !self.conn.buffer_set.write_buffer.is_empty() {
231 write_flush(&mut self.conn.buffer_set.write_buffer);
232 self.conn
233 .stream
234 .write_all(&self.conn.buffer_set.write_buffer)
235 .await?;
236 self.conn.stream.flush().await?;
237 self.conn.buffer_set.write_buffer.clear();
238 }
239 Ok(())
240 }
241
242 pub async fn sync(&mut self) -> Result<()> {
247 let result = self.sync_inner().await;
248 if let Err(e) = &result
249 && e.is_connection_broken()
250 {
251 self.conn.is_broken = true;
252 }
253 result
254 }
255
256 async fn sync_inner(&mut self) -> Result<()> {
257 write_sync(&mut self.conn.buffer_set.write_buffer);
258 self.expectations.push_back(Expectation::Sync);
259 self.conn
260 .stream
261 .write_all(&self.conn.buffer_set.write_buffer)
262 .await?;
263 self.conn.stream.flush().await?;
264 self.conn.buffer_set.write_buffer.clear();
265 Ok(())
266 }
267
268 async fn consume_ready_for_query(&mut self) -> Result<()> {
270 loop {
271 self.conn
272 .stream
273 .read_message(&mut self.conn.buffer_set)
274 .await?;
275 let type_byte = self.conn.buffer_set.type_byte;
276
277 if RawMessage::is_async_type(type_byte) {
278 continue;
279 }
280
281 if type_byte == msg_type::ERROR_RESPONSE {
282 let error = ErrorResponse::parse(&self.conn.buffer_set.read_buffer)?;
283 return Err(error.into_error());
284 }
285
286 if type_byte == msg_type::READY_FOR_QUERY {
287 let ready = ReadyForQuery::parse(&self.conn.buffer_set.read_buffer)?;
288 self.conn.transaction_status = ready.transaction_status().unwrap_or_default();
289 return Ok(());
290 }
291 }
292 }
293
294 async fn consume_pending_syncs(&mut self) -> Result<()> {
296 while self.expectations.front() == Some(&Expectation::Sync) {
297 self.expectations.pop_front();
298 self.consume_ready_for_query().await?;
299 self.aborted = false;
301 }
302 Ok(())
303 }
304
305 #[cfg(feature = "lowlevel")]
313 pub async fn claim<H: BinaryHandler>(
314 &mut self,
315 ticket: Ticket<'_>,
316 handler: &mut H,
317 ) -> Result<()> {
318 self.claim_with_handler(ticket, handler).await
319 }
320
321 async fn claim_with_handler<H: BinaryHandler>(
322 &mut self,
323 ticket: Ticket<'_>,
324 handler: &mut H,
325 ) -> Result<()> {
326 self.check_sequence(ticket.seq)?;
327
328 if !self.conn.buffer_set.write_buffer.is_empty() {
330 self.sync().await?;
331 }
332
333 if self.aborted {
334 self.claim_seq += 1;
335 self.expectations.pop_front();
337 self.consume_pending_syncs().await?;
338 return Err(Error::Protocol(
339 "pipeline aborted due to earlier error".into(),
340 ));
341 }
342
343 let expectation = self.expectations.pop_front();
344
345 let result = match expectation {
346 Some(Expectation::ParseBindExecute) => self.claim_parse_bind_exec_inner(handler).await,
347 Some(Expectation::BindExecute) => {
348 self.claim_bind_exec_inner(handler, ticket.stmt).await
349 }
350 Some(Expectation::Sync) => Err(Error::Protocol("unexpected Sync expectation".into())),
351 None => Err(Error::Protocol("no expectation in queue".into())),
352 };
353
354 if let Err(e) = &result {
355 if e.is_connection_broken() {
356 self.conn.is_broken = true;
357 }
358 self.aborted = true;
359 }
360 self.claim_seq += 1;
361 self.consume_pending_syncs().await?;
362 result
363 }
364
365 pub async fn claim_collect<T: for<'b> FromRow<'b>>(
369 &mut self,
370 ticket: Ticket<'_>,
371 ) -> Result<Vec<T>> {
372 let mut handler = crate::handler::CollectHandler::<T>::new();
373 self.claim_with_handler(ticket, &mut handler).await?;
374 Ok(handler.into_rows())
375 }
376
377 pub async fn claim_one<T: for<'b> FromRow<'b>>(
381 &mut self,
382 ticket: Ticket<'_>,
383 ) -> Result<Option<T>> {
384 let mut handler = crate::handler::FirstRowHandler::<T>::new();
385 self.claim_with_handler(ticket, &mut handler).await?;
386 Ok(handler.into_row())
387 }
388
389 pub async fn claim_drop(&mut self, ticket: Ticket<'_>) -> Result<()> {
393 let mut handler = crate::handler::DropHandler::new();
394 self.claim_with_handler(ticket, &mut handler).await
395 }
396
397 fn check_sequence(&self, seq: usize) -> Result<()> {
399 if seq != self.claim_seq {
400 return Err(Error::InvalidUsage(format!(
401 "claim out of order: expected seq {}, got {}",
402 self.claim_seq, seq
403 )));
404 }
405 Ok(())
406 }
407
408 async fn claim_parse_bind_exec_inner<H: BinaryHandler>(
410 &mut self,
411 handler: &mut H,
412 ) -> Result<()> {
413 self.read_next_message().await?;
415 if self.conn.buffer_set.type_byte != msg_type::PARSE_COMPLETE {
416 return self.unexpected_message("ParseComplete");
417 }
418 ParseComplete::parse(&self.conn.buffer_set.read_buffer)?;
419
420 self.read_next_message().await?;
422 if self.conn.buffer_set.type_byte != msg_type::BIND_COMPLETE {
423 return self.unexpected_message("BindComplete");
424 }
425 BindComplete::parse(&self.conn.buffer_set.read_buffer)?;
426
427 self.claim_rows_inner(handler).await
429 }
430
431 async fn claim_bind_exec_inner<H: BinaryHandler>(
433 &mut self,
434 handler: &mut H,
435 stmt: Option<&PreparedStatement>,
436 ) -> Result<()> {
437 self.read_next_message().await?;
439 if self.conn.buffer_set.type_byte != msg_type::BIND_COMPLETE {
440 return self.unexpected_message("BindComplete");
441 }
442 BindComplete::parse(&self.conn.buffer_set.read_buffer)?;
443
444 let row_desc = stmt.and_then(|s| s.row_desc_payload());
446
447 self.claim_rows_cached_inner(handler, row_desc).await
449 }
450
451 async fn claim_rows_inner<H: BinaryHandler>(&mut self, handler: &mut H) -> Result<()> {
453 self.read_next_message().await?;
455 let has_rows = match self.conn.buffer_set.type_byte {
456 msg_type::ROW_DESCRIPTION => {
457 self.column_buffer.clear();
458 self.column_buffer
459 .extend_from_slice(&self.conn.buffer_set.read_buffer);
460 true
461 }
462 msg_type::NO_DATA => {
463 NoData::parse(&self.conn.buffer_set.read_buffer)?;
464 false
466 }
467 _ => {
468 return Err(Error::Protocol(format!(
469 "expected RowDescription or NoData, got '{}'",
470 self.conn.buffer_set.type_byte as char
471 )));
472 }
473 };
474
475 loop {
477 self.read_next_message().await?;
478 let type_byte = self.conn.buffer_set.type_byte;
479
480 match type_byte {
481 msg_type::DATA_ROW => {
482 if !has_rows {
483 return Err(Error::Protocol(
484 "received DataRow but no RowDescription".into(),
485 ));
486 }
487 let cols = RowDescription::parse(&self.column_buffer)?;
488 let row = DataRow::parse(&self.conn.buffer_set.read_buffer)?;
489 handler.row(cols, row)?;
490 }
491 msg_type::COMMAND_COMPLETE => {
492 let cmd = CommandComplete::parse(&self.conn.buffer_set.read_buffer)?;
493 handler.result_end(cmd)?;
494 return Ok(());
495 }
496 msg_type::EMPTY_QUERY_RESPONSE => {
497 EmptyQueryResponse::parse(&self.conn.buffer_set.read_buffer)?;
498 return Ok(());
499 }
500 _ => {
501 return Err(Error::Protocol(format!(
502 "unexpected message type in pipeline claim: '{}'",
503 type_byte as char
504 )));
505 }
506 }
507 }
508 }
509
510 async fn claim_rows_cached_inner<H: BinaryHandler>(
512 &mut self,
513 handler: &mut H,
514 row_desc: Option<&[u8]>,
515 ) -> Result<()> {
516 loop {
518 self.read_next_message().await?;
519 let type_byte = self.conn.buffer_set.type_byte;
520
521 match type_byte {
522 msg_type::DATA_ROW => {
523 let row_desc = row_desc.ok_or_else(|| {
524 Error::Protocol("received DataRow but no RowDescription cached".into())
525 })?;
526 let cols = RowDescription::parse(row_desc)?;
527 let row = DataRow::parse(&self.conn.buffer_set.read_buffer)?;
528 handler.row(cols, row)?;
529 }
530 msg_type::COMMAND_COMPLETE => {
531 let cmd = CommandComplete::parse(&self.conn.buffer_set.read_buffer)?;
532 handler.result_end(cmd)?;
533 return Ok(());
534 }
535 msg_type::EMPTY_QUERY_RESPONSE => {
536 EmptyQueryResponse::parse(&self.conn.buffer_set.read_buffer)?;
537 return Ok(());
538 }
539 _ => {
540 return Err(Error::Protocol(format!(
541 "unexpected message type in pipeline claim: '{}'",
542 type_byte as char
543 )));
544 }
545 }
546 }
547 }
548
549 async fn read_next_message(&mut self) -> Result<()> {
551 loop {
552 self.conn
553 .stream
554 .read_message(&mut self.conn.buffer_set)
555 .await?;
556 let type_byte = self.conn.buffer_set.type_byte;
557
558 if RawMessage::is_async_type(type_byte) {
560 continue;
561 }
562
563 if type_byte == msg_type::ERROR_RESPONSE {
565 let error = ErrorResponse::parse(&self.conn.buffer_set.read_buffer)?;
566 return Err(error.into_error());
567 }
568
569 return Ok(());
570 }
571 }
572
573 fn unexpected_message<T>(&self, expected: &str) -> Result<T> {
575 Err(Error::Protocol(format!(
576 "expected {}, got '{}'",
577 expected, self.conn.buffer_set.type_byte as char
578 )))
579 }
580
581 pub fn pending_count(&self) -> usize {
583 self.queue_seq - self.claim_seq
584 }
585
586 pub fn is_aborted(&self) -> bool {
588 self.aborted
589 }
590}