// zero_postgres/sync/pipeline/mod.rs
use crate::pipeline::Expectation;
use crate::pipeline::Ticket;

use crate::conversion::{FromRow, ToParams};
use crate::error::{Error, Result};
use crate::handler::BinaryHandler;
use crate::protocol::backend::{
    BindComplete, CommandComplete, DataRow, EmptyQueryResponse, ErrorResponse, NoData,
    ParseComplete, RawMessage, ReadyForQuery, RowDescription, msg_type,
};
use crate::protocol::frontend::{
    write_bind, write_describe_portal, write_execute, write_flush, write_parse, write_sync,
};
use crate::state::extended::PreparedStatement;
use crate::statement::IntoStatement;

use super::conn::Conn;
48use super::conn::Conn;
49
50pub struct Pipeline<'a> {
54 conn: &'a mut Conn,
55 queue_seq: usize,
57 claim_seq: usize,
59 aborted: bool,
61 column_buffer: Vec<u8>,
63 expectations: Vec<Expectation>,
65}
66
67impl<'a> Pipeline<'a> {
68 #[cfg(feature = "lowlevel")]
73 pub fn new(conn: &'a mut Conn) -> Self {
74 Self::new_inner(conn)
75 }
76
77 pub(crate) fn new_inner(conn: &'a mut Conn) -> Self {
79 conn.buffer_set.write_buffer.clear();
80 Self {
81 conn,
82 queue_seq: 0,
83 claim_seq: 0,
84 aborted: false,
85 column_buffer: Vec::new(),
86 expectations: Vec::new(),
87 }
88 }
89
90 #[cfg(feature = "lowlevel")]
95 pub fn cleanup(&mut self) {
96 self.cleanup_inner();
97 }
98
99 #[cfg(not(feature = "lowlevel"))]
100 pub(crate) fn cleanup(&mut self) {
101 self.cleanup_inner();
102 }
103
104 fn cleanup_inner(&mut self) {
105 if self.queue_seq == self.claim_seq {
106 return;
107 }
108
109 if !self.conn.buffer_set.write_buffer.is_empty() {
111 let _ = self.sync();
112 }
113
114 while self.claim_seq < self.queue_seq {
116 let _ = self.drain_one();
117 self.claim_seq += 1;
118 }
119
120 let _ = self.finish();
122 }
123
124 fn drain_one(&mut self) {
126 let Some(expectation) = self.expectations.get(self.claim_seq).copied() else {
127 return;
128 };
129 let mut handler = crate::handler::DropHandler::new();
130
131 let _ = match expectation {
132 Expectation::ParseBindExecute => self.claim_parse_bind_exec_inner(&mut handler),
133 Expectation::BindExecute => self.claim_bind_exec_inner(&mut handler, None),
136 };
137 }
138
139 pub fn exec<'s, P: ToParams>(
168 &mut self,
169 statement: &'s (impl IntoStatement + ?Sized),
170 params: P,
171 ) -> Result<Ticket<'s>> {
172 let seq = self.queue_seq;
173 self.queue_seq += 1;
174
175 if statement.needs_parse() {
176 self.exec_sql_inner(statement.as_sql().unwrap(), ¶ms)?;
177 Ok(Ticket { seq, stmt: None })
178 } else {
179 let stmt = statement.as_prepared().unwrap();
180 self.exec_prepared_inner(&stmt.wire_name(), &stmt.param_oids, ¶ms)?;
181 Ok(Ticket {
182 seq,
183 stmt: Some(stmt),
184 })
185 }
186 }
187
188 fn exec_sql_inner<P: ToParams>(&mut self, sql: &str, params: &P) -> Result<()> {
189 let param_oids = params.natural_oids();
190 let buf = &mut self.conn.buffer_set.write_buffer;
191 write_parse(buf, "", sql, ¶m_oids);
192 write_bind(buf, "", "", params, ¶m_oids)?;
193 write_describe_portal(buf, "");
194 write_execute(buf, "", 0);
195 self.expectations.push(Expectation::ParseBindExecute);
196 Ok(())
197 }
198
199 fn exec_prepared_inner<P: ToParams>(
200 &mut self,
201 stmt_name: &str,
202 param_oids: &[u32],
203 params: &P,
204 ) -> Result<()> {
205 let buf = &mut self.conn.buffer_set.write_buffer;
206 write_bind(buf, "", stmt_name, params, param_oids)?;
207 write_execute(buf, "", 0);
209 self.expectations.push(Expectation::BindExecute);
210 Ok(())
211 }
212
213 pub fn flush(&mut self) -> Result<()> {
218 if !self.conn.buffer_set.write_buffer.is_empty() {
219 write_flush(&mut self.conn.buffer_set.write_buffer);
220 self.conn
221 .stream
222 .write_all(&self.conn.buffer_set.write_buffer)?;
223 self.conn.stream.flush()?;
224 self.conn.buffer_set.write_buffer.clear();
225 }
226 Ok(())
227 }
228
229 pub fn sync(&mut self) -> Result<()> {
235 let result = self.sync_inner();
236 if let Err(e) = &result
237 && e.is_connection_broken()
238 {
239 self.conn.is_broken = true;
240 }
241 result
242 }
243
244 fn sync_inner(&mut self) -> Result<()> {
245 write_sync(&mut self.conn.buffer_set.write_buffer);
246 self.conn
247 .stream
248 .write_all(&self.conn.buffer_set.write_buffer)?;
249 self.conn.stream.flush()?;
250 self.conn.buffer_set.write_buffer.clear();
251 Ok(())
252 }
253
254 fn finish(&mut self) -> Result<()> {
256 loop {
258 self.conn.stream.read_message(&mut self.conn.buffer_set)?;
259 let type_byte = self.conn.buffer_set.type_byte;
260
261 if RawMessage::is_async_type(type_byte) {
263 continue;
264 }
265
266 if type_byte == msg_type::ERROR_RESPONSE {
268 let error = ErrorResponse::parse(&self.conn.buffer_set.read_buffer)?;
269 return Err(error.into_error());
270 }
271
272 if type_byte == msg_type::READY_FOR_QUERY {
273 let ready = ReadyForQuery::parse(&self.conn.buffer_set.read_buffer)?;
274 self.conn.transaction_status = ready.transaction_status().unwrap_or_default();
275 self.queue_seq = 0;
277 self.claim_seq = 0;
278 self.expectations.clear();
279 self.aborted = false;
280 return Ok(());
281 }
282 }
283 }
284
285 #[cfg(feature = "lowlevel")]
293 pub fn claim<H: BinaryHandler>(&mut self, ticket: Ticket<'_>, handler: &mut H) -> Result<()> {
294 self.claim_with_handler(ticket, handler)
295 }
296
297 fn claim_with_handler<H: BinaryHandler>(
298 &mut self,
299 ticket: Ticket<'_>,
300 handler: &mut H,
301 ) -> Result<()> {
302 self.check_sequence(ticket.seq)?;
303 self.flush()?;
304
305 if self.aborted {
306 self.claim_seq += 1;
307 self.maybe_finish()?;
308 return Err(Error::Protocol(
309 "pipeline aborted due to earlier error".into(),
310 ));
311 }
312
313 let expectation = self.expectations.get(ticket.seq).copied();
314
315 let result = match expectation {
316 Some(Expectation::ParseBindExecute) => self.claim_parse_bind_exec_inner(handler),
317 Some(Expectation::BindExecute) => self.claim_bind_exec_inner(handler, ticket.stmt),
318 None => Err(Error::Protocol("unexpected expectation type".into())),
319 };
320
321 if let Err(e) = &result {
322 if e.is_connection_broken() {
323 self.conn.is_broken = true;
324 }
325 self.aborted = true;
326 }
327 self.claim_seq += 1;
328 self.maybe_finish()?;
329 result
330 }
331
332 pub fn claim_collect<T: for<'b> FromRow<'b>>(&mut self, ticket: Ticket<'_>) -> Result<Vec<T>> {
336 let mut handler = crate::handler::CollectHandler::<T>::new();
337 self.claim_with_handler(ticket, &mut handler)?;
338 Ok(handler.into_rows())
339 }
340
341 pub fn claim_one<T: for<'b> FromRow<'b>>(&mut self, ticket: Ticket<'_>) -> Result<Option<T>> {
345 let mut handler = crate::handler::FirstRowHandler::<T>::new();
346 self.claim_with_handler(ticket, &mut handler)?;
347 Ok(handler.into_row())
348 }
349
350 pub fn claim_drop(&mut self, ticket: Ticket<'_>) -> Result<()> {
354 let mut handler = crate::handler::DropHandler::new();
355 self.claim_with_handler(ticket, &mut handler)
356 }
357
358 fn check_sequence(&self, seq: usize) -> Result<()> {
360 if seq != self.claim_seq {
361 return Err(Error::InvalidUsage(format!(
362 "claim out of order: expected seq {}, got {}",
363 self.claim_seq, seq
364 )));
365 }
366 Ok(())
367 }
368
369 fn maybe_finish(&mut self) -> Result<()> {
371 if self.claim_seq == self.queue_seq {
372 self.finish()?;
373 }
374 Ok(())
375 }
376
377 fn claim_parse_bind_exec_inner<H: BinaryHandler>(&mut self, handler: &mut H) -> Result<()> {
379 self.read_next_message()?;
381 if self.conn.buffer_set.type_byte != msg_type::PARSE_COMPLETE {
382 return self.unexpected_message("ParseComplete");
383 }
384 ParseComplete::parse(&self.conn.buffer_set.read_buffer)?;
385
386 self.read_next_message()?;
388 if self.conn.buffer_set.type_byte != msg_type::BIND_COMPLETE {
389 return self.unexpected_message("BindComplete");
390 }
391 BindComplete::parse(&self.conn.buffer_set.read_buffer)?;
392
393 self.claim_rows_inner(handler)
395 }
396
397 fn claim_bind_exec_inner<H: BinaryHandler>(
399 &mut self,
400 handler: &mut H,
401 stmt: Option<&PreparedStatement>,
402 ) -> Result<()> {
403 self.read_next_message()?;
405 if self.conn.buffer_set.type_byte != msg_type::BIND_COMPLETE {
406 return self.unexpected_message("BindComplete");
407 }
408 BindComplete::parse(&self.conn.buffer_set.read_buffer)?;
409
410 let row_desc = stmt.and_then(|s| s.row_desc_payload());
412
413 self.claim_rows_cached_inner(handler, row_desc)
415 }
416
417 fn claim_rows_inner<H: BinaryHandler>(&mut self, handler: &mut H) -> Result<()> {
419 self.read_next_message()?;
421 let has_rows = match self.conn.buffer_set.type_byte {
422 msg_type::ROW_DESCRIPTION => {
423 self.column_buffer.clear();
424 self.column_buffer
425 .extend_from_slice(&self.conn.buffer_set.read_buffer);
426 true
427 }
428 msg_type::NO_DATA => {
429 NoData::parse(&self.conn.buffer_set.read_buffer)?;
430 false
432 }
433 _ => {
434 return Err(Error::Protocol(format!(
435 "expected RowDescription or NoData, got '{}'",
436 self.conn.buffer_set.type_byte as char
437 )));
438 }
439 };
440
441 loop {
443 self.read_next_message()?;
444 let type_byte = self.conn.buffer_set.type_byte;
445
446 match type_byte {
447 msg_type::DATA_ROW => {
448 if !has_rows {
449 return Err(Error::Protocol(
450 "received DataRow but no RowDescription".into(),
451 ));
452 }
453 let cols = RowDescription::parse(&self.column_buffer)?;
454 let row = DataRow::parse(&self.conn.buffer_set.read_buffer)?;
455 handler.row(cols, row)?;
456 }
457 msg_type::COMMAND_COMPLETE => {
458 let cmd = CommandComplete::parse(&self.conn.buffer_set.read_buffer)?;
459 handler.result_end(cmd)?;
460 return Ok(());
461 }
462 msg_type::EMPTY_QUERY_RESPONSE => {
463 EmptyQueryResponse::parse(&self.conn.buffer_set.read_buffer)?;
464 return Ok(());
465 }
466 _ => {
467 return Err(Error::Protocol(format!(
468 "unexpected message type in pipeline claim: '{}'",
469 type_byte as char
470 )));
471 }
472 }
473 }
474 }
475
476 fn claim_rows_cached_inner<H: BinaryHandler>(
478 &mut self,
479 handler: &mut H,
480 row_desc: Option<&[u8]>,
481 ) -> Result<()> {
482 loop {
484 self.read_next_message()?;
485 let type_byte = self.conn.buffer_set.type_byte;
486
487 match type_byte {
488 msg_type::DATA_ROW => {
489 let row_desc = row_desc.ok_or_else(|| {
490 Error::Protocol("received DataRow but no RowDescription cached".into())
491 })?;
492 let cols = RowDescription::parse(row_desc)?;
493 let row = DataRow::parse(&self.conn.buffer_set.read_buffer)?;
494 handler.row(cols, row)?;
495 }
496 msg_type::COMMAND_COMPLETE => {
497 let cmd = CommandComplete::parse(&self.conn.buffer_set.read_buffer)?;
498 handler.result_end(cmd)?;
499 return Ok(());
500 }
501 msg_type::EMPTY_QUERY_RESPONSE => {
502 EmptyQueryResponse::parse(&self.conn.buffer_set.read_buffer)?;
503 return Ok(());
504 }
505 _ => {
506 return Err(Error::Protocol(format!(
507 "unexpected message type in pipeline claim: '{}'",
508 type_byte as char
509 )));
510 }
511 }
512 }
513 }
514
515 fn read_next_message(&mut self) -> Result<()> {
517 loop {
518 self.conn.stream.read_message(&mut self.conn.buffer_set)?;
519 let type_byte = self.conn.buffer_set.type_byte;
520
521 if RawMessage::is_async_type(type_byte) {
523 continue;
524 }
525
526 if type_byte == msg_type::ERROR_RESPONSE {
528 let error = ErrorResponse::parse(&self.conn.buffer_set.read_buffer)?;
529 return Err(error.into_error());
530 }
531
532 return Ok(());
533 }
534 }
535
536 fn unexpected_message<T>(&self, expected: &str) -> Result<T> {
538 Err(Error::Protocol(format!(
539 "expected {}, got '{}'",
540 expected, self.conn.buffer_set.type_byte as char
541 )))
542 }
543
544 pub fn pending_count(&self) -> usize {
546 self.queue_seq - self.claim_seq
547 }
548
549 pub fn is_aborted(&self) -> bool {
551 self.aborted
552 }
553}