zero_postgres/tokio/pipeline/
mod.rs1use crate::pipeline::Expectation;
33use crate::pipeline::Ticket;
34
35use crate::conversion::{FromRow, ToParams};
36use crate::error::{Error, Result};
37use crate::handler::BinaryHandler;
38use crate::protocol::backend::{
39 BindComplete, CommandComplete, DataRow, EmptyQueryResponse, ErrorResponse, NoData,
40 ParseComplete, RawMessage, ReadyForQuery, RowDescription, msg_type,
41};
42use crate::protocol::frontend::{
43 write_bind, write_describe_portal, write_execute, write_flush, write_parse, write_sync,
44};
45use crate::state::extended::PreparedStatement;
46use crate::statement::IntoStatement;
47
48use super::conn::Conn;
49
50pub struct Pipeline<'a> {
54 conn: &'a mut Conn,
55 queue_seq: usize,
57 claim_seq: usize,
59 aborted: bool,
61 column_buffer: Vec<u8>,
63 expectations: Vec<Expectation>,
65}
66
67impl<'a> Pipeline<'a> {
/// Creates a pipeline on top of `conn`.
///
/// Only exported when the `lowlevel` feature is enabled; it simply forwards
/// to [`Pipeline::new_inner`].
#[cfg(feature = "lowlevel")]
pub fn new(conn: &'a mut Conn) -> Self {
    Pipeline::new_inner(conn)
}
76
77 pub(crate) fn new_inner(conn: &'a mut Conn) -> Self {
79 conn.buffer_set.write_buffer.clear();
80 Self {
81 conn,
82 queue_seq: 0,
83 claim_seq: 0,
84 aborted: false,
85 column_buffer: Vec::new(),
86 expectations: Vec::new(),
87 }
88 }
89
90 #[cfg(feature = "lowlevel")]
95 pub async fn cleanup(&mut self) {
96 self.cleanup_inner().await;
97 }
98
99 #[cfg(not(feature = "lowlevel"))]
100 pub(crate) async fn cleanup(&mut self) {
101 self.cleanup_inner().await;
102 }
103
104 async fn cleanup_inner(&mut self) {
105 if self.queue_seq == self.claim_seq {
106 return;
107 }
108
109 if !self.conn.buffer_set.write_buffer.is_empty() {
111 let _ = self.sync().await;
112 }
113
114 while self.claim_seq < self.queue_seq {
116 let _ = self.drain_one().await;
117 self.claim_seq += 1;
118 }
119
120 let _ = self.finish().await;
122 }
123
124 async fn drain_one(&mut self) {
126 let Some(expectation) = self.expectations.get(self.claim_seq).copied() else {
127 return;
128 };
129 let mut handler = crate::handler::DropHandler::new();
130
131 let _ = match expectation {
132 Expectation::ParseBindExecute => self.claim_parse_bind_exec_inner(&mut handler).await,
133 Expectation::BindExecute => self.claim_bind_exec_inner(&mut handler, None).await,
136 };
137 }
138
139 pub fn exec<'s, P: ToParams>(
168 &mut self,
169 statement: &'s (impl IntoStatement + ?Sized),
170 params: P,
171 ) -> Result<Ticket<'s>> {
172 let seq = self.queue_seq;
173 self.queue_seq += 1;
174
175 if statement.needs_parse() {
176 self.exec_sql_inner(statement.as_sql().unwrap(), ¶ms)?;
177 Ok(Ticket { seq, stmt: None })
178 } else {
179 let stmt = statement.as_prepared().unwrap();
180 self.exec_prepared_inner(&stmt.wire_name(), &stmt.param_oids, ¶ms)?;
181 Ok(Ticket {
182 seq,
183 stmt: Some(stmt),
184 })
185 }
186 }
187
188 fn exec_sql_inner<P: ToParams>(&mut self, sql: &str, params: &P) -> Result<()> {
189 let param_oids = params.natural_oids();
190 let buf = &mut self.conn.buffer_set.write_buffer;
191 write_parse(buf, "", sql, ¶m_oids);
192 write_bind(buf, "", "", params, ¶m_oids)?;
193 write_describe_portal(buf, "");
194 write_execute(buf, "", 0);
195 self.expectations.push(Expectation::ParseBindExecute);
196 Ok(())
197 }
198
199 fn exec_prepared_inner<P: ToParams>(
200 &mut self,
201 stmt_name: &str,
202 param_oids: &[u32],
203 params: &P,
204 ) -> Result<()> {
205 let buf = &mut self.conn.buffer_set.write_buffer;
206 write_bind(buf, "", stmt_name, params, param_oids)?;
207 write_execute(buf, "", 0);
209 self.expectations.push(Expectation::BindExecute);
210 Ok(())
211 }
212
213 pub async fn flush(&mut self) -> Result<()> {
218 if !self.conn.buffer_set.write_buffer.is_empty() {
219 write_flush(&mut self.conn.buffer_set.write_buffer);
220 self.conn
221 .stream
222 .write_all(&self.conn.buffer_set.write_buffer)
223 .await?;
224 self.conn.stream.flush().await?;
225 self.conn.buffer_set.write_buffer.clear();
226 }
227 Ok(())
228 }
229
230 pub async fn sync(&mut self) -> Result<()> {
236 let result = self.sync_inner().await;
237 if let Err(e) = &result
238 && e.is_connection_broken()
239 {
240 self.conn.is_broken = true;
241 }
242 result
243 }
244
245 async fn sync_inner(&mut self) -> Result<()> {
246 write_sync(&mut self.conn.buffer_set.write_buffer);
247 self.conn
248 .stream
249 .write_all(&self.conn.buffer_set.write_buffer)
250 .await?;
251 self.conn.stream.flush().await?;
252 self.conn.buffer_set.write_buffer.clear();
253 Ok(())
254 }
255
256 async fn finish(&mut self) -> Result<()> {
258 loop {
260 self.conn
261 .stream
262 .read_message(&mut self.conn.buffer_set)
263 .await?;
264 let type_byte = self.conn.buffer_set.type_byte;
265
266 if RawMessage::is_async_type(type_byte) {
268 continue;
269 }
270
271 if type_byte == msg_type::ERROR_RESPONSE {
273 let error = ErrorResponse::parse(&self.conn.buffer_set.read_buffer)?;
274 return Err(error.into_error());
275 }
276
277 if type_byte == msg_type::READY_FOR_QUERY {
278 let ready = ReadyForQuery::parse(&self.conn.buffer_set.read_buffer)?;
279 self.conn.transaction_status = ready.transaction_status().unwrap_or_default();
280 self.queue_seq = 0;
282 self.claim_seq = 0;
283 self.expectations.clear();
284 self.aborted = false;
285 return Ok(());
286 }
287 }
288 }
289
/// Claims the result of `ticket`, feeding rows and the completion tag to `handler`.
///
/// Only exported when the `lowlevel` feature is enabled; forwards to
/// `claim_with_handler`.
#[cfg(feature = "lowlevel")]
pub async fn claim<H: BinaryHandler>(
    &mut self,
    ticket: Ticket<'_>,
    handler: &mut H,
) -> Result<()> {
    Self::claim_with_handler(self, ticket, handler).await
}
305
306 async fn claim_with_handler<H: BinaryHandler>(
307 &mut self,
308 ticket: Ticket<'_>,
309 handler: &mut H,
310 ) -> Result<()> {
311 self.check_sequence(ticket.seq)?;
312 self.flush().await?;
313
314 if self.aborted {
315 self.claim_seq += 1;
316 self.maybe_finish().await?;
317 return Err(Error::Protocol(
318 "pipeline aborted due to earlier error".into(),
319 ));
320 }
321
322 let expectation = self.expectations.get(ticket.seq).copied();
323
324 let result = match expectation {
325 Some(Expectation::ParseBindExecute) => self.claim_parse_bind_exec_inner(handler).await,
326 Some(Expectation::BindExecute) => {
327 self.claim_bind_exec_inner(handler, ticket.stmt).await
328 }
329 None => Err(Error::Protocol("unexpected expectation type".into())),
330 };
331
332 if let Err(e) = &result {
333 if e.is_connection_broken() {
334 self.conn.is_broken = true;
335 }
336 self.aborted = true;
337 }
338 self.claim_seq += 1;
339 self.maybe_finish().await?;
340 result
341 }
342
343 pub async fn claim_collect<T: for<'b> FromRow<'b>>(
347 &mut self,
348 ticket: Ticket<'_>,
349 ) -> Result<Vec<T>> {
350 let mut handler = crate::handler::CollectHandler::<T>::new();
351 self.claim_with_handler(ticket, &mut handler).await?;
352 Ok(handler.into_rows())
353 }
354
355 pub async fn claim_one<T: for<'b> FromRow<'b>>(
359 &mut self,
360 ticket: Ticket<'_>,
361 ) -> Result<Option<T>> {
362 let mut handler = crate::handler::FirstRowHandler::<T>::new();
363 self.claim_with_handler(ticket, &mut handler).await?;
364 Ok(handler.into_row())
365 }
366
367 pub async fn claim_drop(&mut self, ticket: Ticket<'_>) -> Result<()> {
371 let mut handler = crate::handler::DropHandler::new();
372 self.claim_with_handler(ticket, &mut handler).await
373 }
374
375 fn check_sequence(&self, seq: usize) -> Result<()> {
377 if seq != self.claim_seq {
378 return Err(Error::InvalidUsage(format!(
379 "claim out of order: expected seq {}, got {}",
380 self.claim_seq, seq
381 )));
382 }
383 Ok(())
384 }
385
386 async fn maybe_finish(&mut self) -> Result<()> {
388 if self.claim_seq == self.queue_seq {
389 self.finish().await?;
390 }
391 Ok(())
392 }
393
394 async fn claim_parse_bind_exec_inner<H: BinaryHandler>(
396 &mut self,
397 handler: &mut H,
398 ) -> Result<()> {
399 self.read_next_message().await?;
401 if self.conn.buffer_set.type_byte != msg_type::PARSE_COMPLETE {
402 return self.unexpected_message("ParseComplete");
403 }
404 ParseComplete::parse(&self.conn.buffer_set.read_buffer)?;
405
406 self.read_next_message().await?;
408 if self.conn.buffer_set.type_byte != msg_type::BIND_COMPLETE {
409 return self.unexpected_message("BindComplete");
410 }
411 BindComplete::parse(&self.conn.buffer_set.read_buffer)?;
412
413 self.claim_rows_inner(handler).await
415 }
416
417 async fn claim_bind_exec_inner<H: BinaryHandler>(
419 &mut self,
420 handler: &mut H,
421 stmt: Option<&PreparedStatement>,
422 ) -> Result<()> {
423 self.read_next_message().await?;
425 if self.conn.buffer_set.type_byte != msg_type::BIND_COMPLETE {
426 return self.unexpected_message("BindComplete");
427 }
428 BindComplete::parse(&self.conn.buffer_set.read_buffer)?;
429
430 let row_desc = stmt.and_then(|s| s.row_desc_payload());
432
433 self.claim_rows_cached_inner(handler, row_desc).await
435 }
436
437 async fn claim_rows_inner<H: BinaryHandler>(&mut self, handler: &mut H) -> Result<()> {
439 self.read_next_message().await?;
441 let has_rows = match self.conn.buffer_set.type_byte {
442 msg_type::ROW_DESCRIPTION => {
443 self.column_buffer.clear();
444 self.column_buffer
445 .extend_from_slice(&self.conn.buffer_set.read_buffer);
446 true
447 }
448 msg_type::NO_DATA => {
449 NoData::parse(&self.conn.buffer_set.read_buffer)?;
450 false
452 }
453 _ => {
454 return Err(Error::Protocol(format!(
455 "expected RowDescription or NoData, got '{}'",
456 self.conn.buffer_set.type_byte as char
457 )));
458 }
459 };
460
461 loop {
463 self.read_next_message().await?;
464 let type_byte = self.conn.buffer_set.type_byte;
465
466 match type_byte {
467 msg_type::DATA_ROW => {
468 if !has_rows {
469 return Err(Error::Protocol(
470 "received DataRow but no RowDescription".into(),
471 ));
472 }
473 let cols = RowDescription::parse(&self.column_buffer)?;
474 let row = DataRow::parse(&self.conn.buffer_set.read_buffer)?;
475 handler.row(cols, row)?;
476 }
477 msg_type::COMMAND_COMPLETE => {
478 let cmd = CommandComplete::parse(&self.conn.buffer_set.read_buffer)?;
479 handler.result_end(cmd)?;
480 return Ok(());
481 }
482 msg_type::EMPTY_QUERY_RESPONSE => {
483 EmptyQueryResponse::parse(&self.conn.buffer_set.read_buffer)?;
484 return Ok(());
485 }
486 _ => {
487 return Err(Error::Protocol(format!(
488 "unexpected message type in pipeline claim: '{}'",
489 type_byte as char
490 )));
491 }
492 }
493 }
494 }
495
496 async fn claim_rows_cached_inner<H: BinaryHandler>(
498 &mut self,
499 handler: &mut H,
500 row_desc: Option<&[u8]>,
501 ) -> Result<()> {
502 loop {
504 self.read_next_message().await?;
505 let type_byte = self.conn.buffer_set.type_byte;
506
507 match type_byte {
508 msg_type::DATA_ROW => {
509 let row_desc = row_desc.ok_or_else(|| {
510 Error::Protocol("received DataRow but no RowDescription cached".into())
511 })?;
512 let cols = RowDescription::parse(row_desc)?;
513 let row = DataRow::parse(&self.conn.buffer_set.read_buffer)?;
514 handler.row(cols, row)?;
515 }
516 msg_type::COMMAND_COMPLETE => {
517 let cmd = CommandComplete::parse(&self.conn.buffer_set.read_buffer)?;
518 handler.result_end(cmd)?;
519 return Ok(());
520 }
521 msg_type::EMPTY_QUERY_RESPONSE => {
522 EmptyQueryResponse::parse(&self.conn.buffer_set.read_buffer)?;
523 return Ok(());
524 }
525 _ => {
526 return Err(Error::Protocol(format!(
527 "unexpected message type in pipeline claim: '{}'",
528 type_byte as char
529 )));
530 }
531 }
532 }
533 }
534
535 async fn read_next_message(&mut self) -> Result<()> {
537 loop {
538 self.conn
539 .stream
540 .read_message(&mut self.conn.buffer_set)
541 .await?;
542 let type_byte = self.conn.buffer_set.type_byte;
543
544 if RawMessage::is_async_type(type_byte) {
546 continue;
547 }
548
549 if type_byte == msg_type::ERROR_RESPONSE {
551 let error = ErrorResponse::parse(&self.conn.buffer_set.read_buffer)?;
552 return Err(error.into_error());
553 }
554
555 return Ok(());
556 }
557 }
558
559 fn unexpected_message<T>(&self, expected: &str) -> Result<T> {
561 Err(Error::Protocol(format!(
562 "expected {}, got '{}'",
563 expected, self.conn.buffer_set.type_byte as char
564 )))
565 }
566
567 pub fn pending_count(&self) -> usize {
569 self.queue_seq - self.claim_seq
570 }
571
572 pub fn is_aborted(&self) -> bool {
574 self.aborted
575 }
576}