sentinel_driver/pipeline/
mod.rs1pub mod batch;
2
3use std::sync::Arc;
4
5use bytes::BytesMut;
6
7use crate::error::{Error, Result};
8use crate::protocol::backend::BackendMessage;
9use crate::protocol::frontend;
10use crate::row::{parse_command_tag, CommandResult, Row, RowDescription};
11
/// One statement to be sent as part of a pipeline by `encode_pipeline`.
///
/// Each query is encoded with the unnamed prepared statement and the unnamed
/// portal, so queries are not retained server-side after the pipeline runs.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct PipelineQuery {
    /// SQL text sent in the `Parse` message.
    pub sql: String,
    /// Parameter type OIDs passed to `Parse`. Per the PostgreSQL protocol an
    /// OID of 0 (or omitting trailing entries) lets the server infer the
    /// type — assuming `frontend::parse` forwards them unchanged; confirm.
    pub param_types: Vec<u32>,
    /// Encoded parameter values passed to `Bind`, one per placeholder.
    /// `None` presumably encodes SQL `NULL` — confirm against `frontend::bind`.
    pub params: Vec<Option<Vec<u8>>>,
}
19
20#[derive(Debug)]
22pub enum QueryResult {
23 Rows(Vec<Row>),
25 Command(CommandResult),
27}
28
29impl QueryResult {
30 pub fn into_rows(self) -> Result<Vec<Row>> {
32 match self {
33 QueryResult::Rows(rows) => Ok(rows),
34 QueryResult::Command(_) => Err(Error::Protocol(
35 "expected rows but got command result".to_string(),
36 )),
37 }
38 }
39
40 pub fn into_command(self) -> Result<CommandResult> {
42 match self {
43 QueryResult::Command(r) => Ok(r),
44 QueryResult::Rows(_) => Err(Error::Protocol(
45 "expected command result but got rows".to_string(),
46 )),
47 }
48 }
49}
50
51pub fn encode_pipeline(buf: &mut BytesMut, queries: &[PipelineQuery]) {
56 for q in queries {
57 let oids: Vec<u32> = q.param_types.clone();
59 frontend::parse(buf, "", &q.sql, &oids);
60
61 let param_refs: Vec<Option<&[u8]>> = q.params.iter().map(|p| p.as_deref()).collect();
63 frontend::bind(buf, "", "", ¶m_refs, &[]);
64
65 frontend::describe_portal(buf, "");
67
68 frontend::execute(buf, "", 0);
70 }
71
72 frontend::sync(buf);
74}
75
76pub(crate) async fn read_pipeline_responses(
86 conn: &mut crate::connection::stream::PgConnection,
87 count: usize,
88) -> Result<Vec<QueryResult>> {
89 let mut results = Vec::with_capacity(count);
90
91 for _ in 0..count {
92 expect_message(conn, "ParseComplete", |m| {
94 matches!(m, BackendMessage::ParseComplete)
95 })
96 .await?;
97
98 expect_message(conn, "BindComplete", |m| {
100 matches!(m, BackendMessage::BindComplete)
101 })
102 .await?;
103
104 let msg = conn.recv().await?;
106 let description = match msg {
107 BackendMessage::RowDescription { fields } => {
108 Some(Arc::new(RowDescription::new(fields)))
109 }
110 BackendMessage::NoData => None,
111 BackendMessage::ErrorResponse { fields } => {
112 return Err(Error::server(
113 fields.severity,
114 fields.code,
115 fields.message,
116 fields.detail,
117 fields.hint,
118 fields.position,
119 ));
120 }
121 other => {
122 return Err(Error::protocol(format!(
123 "expected RowDescription or NoData, got {other:?}"
124 )));
125 }
126 };
127
128 let result = read_query_result(conn, description).await?;
130 results.push(result);
131 }
132
133 let msg = conn.recv().await?;
135 match msg {
136 BackendMessage::ReadyForQuery { .. } => {}
137 BackendMessage::ErrorResponse { fields } => {
138 return Err(Error::server(
139 fields.severity,
140 fields.code,
141 fields.message,
142 fields.detail,
143 fields.hint,
144 fields.position,
145 ));
146 }
147 other => {
148 return Err(Error::protocol(format!(
149 "expected ReadyForQuery, got {other:?}"
150 )));
151 }
152 }
153
154 Ok(results)
155}
156
157async fn read_query_result(
159 conn: &mut crate::connection::stream::PgConnection,
160 description: Option<Arc<RowDescription>>,
161) -> Result<QueryResult> {
162 let mut rows = Vec::new();
163
164 loop {
165 let msg = conn.recv().await?;
166 match msg {
167 BackendMessage::DataRow { columns } => {
168 let desc = description
169 .as_ref()
170 .ok_or_else(|| Error::protocol("received DataRow without RowDescription"))?;
171 rows.push(Row::new(columns, Arc::clone(desc)));
172 }
173 BackendMessage::CommandComplete { tag } => {
174 if rows.is_empty() {
175 return Ok(QueryResult::Command(parse_command_tag(&tag)));
176 } else {
177 return Ok(QueryResult::Rows(rows));
178 }
179 }
180 BackendMessage::EmptyQueryResponse => {
181 return Ok(QueryResult::Command(CommandResult {
182 command: String::new(),
183 rows_affected: 0,
184 }));
185 }
186 BackendMessage::ErrorResponse { fields } => {
187 return Err(Error::server(
188 fields.severity,
189 fields.code,
190 fields.message,
191 fields.detail,
192 fields.hint,
193 fields.position,
194 ));
195 }
196 other => {
197 return Err(Error::protocol(format!(
198 "unexpected message in query result: {other:?}"
199 )));
200 }
201 }
202 }
203}
204
205async fn expect_message(
206 conn: &mut crate::connection::stream::PgConnection,
207 expected: &str,
208 check: impl FnOnce(&BackendMessage) -> bool,
209) -> Result<()> {
210 let msg = conn.recv().await?;
211 if check(&msg) {
212 Ok(())
213 } else if let BackendMessage::ErrorResponse { fields } = msg {
214 Err(Error::server(
215 fields.severity,
216 fields.code,
217 fields.message,
218 fields.detail,
219 fields.hint,
220 fields.position,
221 ))
222 } else {
223 Err(Error::protocol(format!("expected {expected}, got {msg:?}")))
224 }
225}