sentinel_driver/connection/query.rs

1use super::{
2    frontend, pipeline, BackendMessage, BytesMut, Connection, Duration, Error, Oid, PipelineBatch,
3    Result, Row, ToSql,
4};
5
6use crate::row::{self, SimpleQueryMessage, SimpleQueryRow};
7
8impl Connection {
9    /// Execute a query that returns rows.
10    ///
11    /// Parameters are encoded in binary format.
12    ///
13    /// ```rust,no_run
14    /// # async fn example(conn: &mut sentinel_driver::Connection) -> sentinel_driver::Result<()> {
15    /// let rows = conn.query("SELECT * FROM users WHERE id = $1", &[&42i32]).await?;
16    /// # Ok(())
17    /// # }
18    /// ```
19    pub async fn query(&mut self, sql: &str, params: &[&(dyn ToSql + Sync)]) -> Result<Vec<Row>> {
20        if let Some(timeout) = self.query_timeout {
21            return self.query_with_timeout(sql, params, timeout).await;
22        }
23
24        let result = self.query_internal(sql, params).await?;
25        match result {
26            pipeline::QueryResult::Rows(rows) => Ok(rows),
27            pipeline::QueryResult::Command(_) => Ok(Vec::new()),
28        }
29    }
30
31    /// Execute a query that returns a single row.
32    ///
33    /// Returns an error if no rows are returned.
34    pub async fn query_one(&mut self, sql: &str, params: &[&(dyn ToSql + Sync)]) -> Result<Row> {
35        let rows = self.query(sql, params).await?;
36        rows.into_iter()
37            .next()
38            .ok_or_else(|| Error::Protocol("query returned no rows".into()))
39    }
40
41    /// Execute a query that returns an optional single row.
42    pub async fn query_opt(
43        &mut self,
44        sql: &str,
45        params: &[&(dyn ToSql + Sync)],
46    ) -> Result<Option<Row>> {
47        let rows = self.query(sql, params).await?;
48        Ok(rows.into_iter().next())
49    }
50
51    /// Execute a non-SELECT query (INSERT, UPDATE, DELETE, etc.).
52    ///
53    /// Returns the number of rows affected.
54    pub async fn execute(&mut self, sql: &str, params: &[&(dyn ToSql + Sync)]) -> Result<u64> {
55        if let Some(timeout) = self.query_timeout {
56            return self.execute_with_timeout(sql, params, timeout).await;
57        }
58
59        let result = self.query_internal(sql, params).await?;
60        match result {
61            pipeline::QueryResult::Command(r) => Ok(r.rows_affected),
62            pipeline::QueryResult::Rows(_) => Ok(0),
63        }
64    }
65
66    /// Execute a query with a timeout.
67    ///
68    /// If the query does not complete within `timeout`, a cancel request
69    /// is sent to the server and the connection is marked as broken.
70    pub async fn query_with_timeout(
71        &mut self,
72        sql: &str,
73        params: &[&(dyn ToSql + Sync)],
74        timeout: Duration,
75    ) -> Result<Vec<Row>> {
76        let cancel_token = self.cancel_token();
77
78        match tokio::time::timeout(timeout, self.query_internal(sql, params)).await {
79            Ok(result) => {
80                let result = result?;
81                match result {
82                    pipeline::QueryResult::Rows(rows) => Ok(rows),
83                    pipeline::QueryResult::Command(_) => Ok(Vec::new()),
84                }
85            }
86            Err(_elapsed) => {
87                self.is_broken = true;
88                // Fire-and-forget cancel
89                tokio::spawn(async move {
90                    cancel_token.cancel().await.ok();
91                });
92                Err(Error::Timeout(format!(
93                    "query timeout after {}ms",
94                    timeout.as_millis()
95                )))
96            }
97        }
98    }
99
100    /// Execute a non-SELECT query with a timeout.
101    ///
102    /// If the query does not complete within `timeout`, a cancel request
103    /// is sent to the server and the connection is marked as broken.
104    pub async fn execute_with_timeout(
105        &mut self,
106        sql: &str,
107        params: &[&(dyn ToSql + Sync)],
108        timeout: Duration,
109    ) -> Result<u64> {
110        let cancel_token = self.cancel_token();
111
112        match tokio::time::timeout(timeout, self.query_internal(sql, params)).await {
113            Ok(result) => {
114                let result = result?;
115                match result {
116                    pipeline::QueryResult::Command(r) => Ok(r.rows_affected),
117                    pipeline::QueryResult::Rows(_) => Ok(0),
118                }
119            }
120            Err(_elapsed) => {
121                self.is_broken = true;
122                tokio::spawn(async move {
123                    cancel_token.cancel().await.ok();
124                });
125                Err(Error::Timeout(format!(
126                    "query timeout after {}ms",
127                    timeout.as_millis()
128                )))
129            }
130        }
131    }
132
133    /// Execute a simple query (no parameters, text protocol).
134    ///
135    /// Returns row data (in text format) and command completions. Useful
136    /// for DDL statements, multi-statement queries, and queries where you
137    /// don't need binary-decoded typed values.
138    ///
139    /// ```rust,no_run
140    /// # async fn example(conn: &mut sentinel_driver::Connection) -> sentinel_driver::Result<()> {
141    /// use sentinel_driver::SimpleQueryMessage;
142    ///
143    /// let messages = conn.simple_query("SELECT 1 AS n; SELECT 'hello' AS greeting").await?;
144    /// for msg in &messages {
145    ///     match msg {
146    ///         SimpleQueryMessage::Row(row) => {
147    ///             println!("value: {:?}", row.get(0));
148    ///         }
149    ///         SimpleQueryMessage::CommandComplete(n) => {
150    ///             println!("rows: {n}");
151    ///         }
152    ///     }
153    /// }
154    /// # Ok(())
155    /// # }
156    /// ```
157    pub async fn simple_query(&mut self, sql: &str) -> Result<Vec<SimpleQueryMessage>> {
158        self.instr().on_event(&crate::Event::ExecuteStart {
159            stmt: crate::StmtRef::Inline { sql },
160            param_count: 0,
161        });
162        let started = std::time::Instant::now();
163        let res = self.simple_query_inner(sql).await;
164        let duration = started.elapsed();
165        let (rows, outcome) = match &res {
166            Ok(msgs) => {
167                let r = msgs
168                    .iter()
169                    .filter_map(|m| match m {
170                        SimpleQueryMessage::CommandComplete(n) => Some(*n),
171                        SimpleQueryMessage::Row(_) => None,
172                    })
173                    .sum::<u64>();
174                (r, crate::Outcome::Ok)
175            }
176            Err(e) => (0, crate::Outcome::Err(e)),
177        };
178        self.instr().on_event(&crate::Event::ExecuteFinish {
179            stmt: crate::StmtRef::Inline { sql },
180            rows,
181            duration,
182            outcome,
183        });
184        res
185    }
186
187    async fn simple_query_inner(&mut self, sql: &str) -> Result<Vec<SimpleQueryMessage>> {
188        frontend::query(self.conn.write_buf(), sql);
189        self.conn.send().await?;
190
191        let mut results = Vec::new();
192
193        loop {
194            match self.conn.recv().await? {
195                BackendMessage::DataRow { columns } => {
196                    // Extract text-format column values from DataRow
197                    let mut text_columns = Vec::with_capacity(columns.len());
198                    for i in 0..columns.len() {
199                        let value = columns
200                            .get(i)
201                            .map(|bytes| String::from_utf8_lossy(&bytes).into_owned());
202                        text_columns.push(value);
203                    }
204                    results.push(SimpleQueryMessage::Row(SimpleQueryRow::new(text_columns)));
205                }
206                BackendMessage::CommandComplete { tag } => {
207                    let parsed = row::parse_command_tag(&tag);
208                    results.push(SimpleQueryMessage::CommandComplete(parsed.rows_affected));
209                }
210                BackendMessage::ReadyForQuery { transaction_status } => {
211                    self.transaction_status = transaction_status;
212                    break;
213                }
214                BackendMessage::ErrorResponse { fields } => {
215                    self.drain_until_ready().await.ok();
216                    return Err(Error::server(
217                        fields.severity,
218                        fields.code,
219                        fields.message,
220                        fields.detail,
221                        fields.hint,
222                        fields.position,
223                    ));
224                }
225                _ => {}
226            }
227        }
228
229        Ok(results)
230    }
231
232    // ── query_typed ────────────────────────────────────
233
234    /// Execute a query with inline parameter types, skipping the prepare step.
235    ///
236    /// Instead of a separate Prepare round-trip, the parameter types are
237    /// specified directly in the Parse message. This saves one round-trip
238    /// compared to [`query()`](Self::query) at the cost of requiring the
239    /// caller to specify types explicitly.
240    ///
241    /// ```rust,no_run
242    /// # async fn example(conn: &mut sentinel_driver::Connection) -> sentinel_driver::Result<()> {
243    /// use sentinel_driver::Oid;
244    ///
245    /// let rows = conn.query_typed(
246    ///     "SELECT $1::int4 + $2::int4 AS sum",
247    ///     &[(&1i32, Oid::INT4), (&2i32, Oid::INT4)],
248    /// ).await?;
249    /// # Ok(())
250    /// # }
251    /// ```
252    pub async fn query_typed(
253        &mut self,
254        sql: &str,
255        params: &[(&(dyn ToSql + Sync), Oid)],
256    ) -> Result<Vec<Row>> {
257        let result = self.query_typed_internal(sql, params).await?;
258        match result {
259            pipeline::QueryResult::Rows(rows) => Ok(rows),
260            pipeline::QueryResult::Command(_) => Ok(Vec::new()),
261        }
262    }
263
264    /// Execute a typed query that returns a single row.
265    pub async fn query_typed_one(
266        &mut self,
267        sql: &str,
268        params: &[(&(dyn ToSql + Sync), Oid)],
269    ) -> Result<Row> {
270        let rows = self.query_typed(sql, params).await?;
271        rows.into_iter()
272            .next()
273            .ok_or_else(|| Error::Protocol("query returned no rows".into()))
274    }
275
276    /// Execute a typed query that returns an optional single row.
277    pub async fn query_typed_opt(
278        &mut self,
279        sql: &str,
280        params: &[(&(dyn ToSql + Sync), Oid)],
281    ) -> Result<Option<Row>> {
282        let rows = self.query_typed(sql, params).await?;
283        Ok(rows.into_iter().next())
284    }
285
286    /// Execute a typed non-SELECT query, returning rows affected.
287    pub async fn execute_typed(
288        &mut self,
289        sql: &str,
290        params: &[(&(dyn ToSql + Sync), Oid)],
291    ) -> Result<u64> {
292        let result = self.query_typed_internal(sql, params).await?;
293        match result {
294            pipeline::QueryResult::Command(r) => Ok(r.rows_affected),
295            pipeline::QueryResult::Rows(_) => Ok(0),
296        }
297    }
298
299    async fn query_typed_internal(
300        &mut self,
301        sql: &str,
302        params: &[(&(dyn ToSql + Sync), Oid)],
303    ) -> Result<pipeline::QueryResult> {
304        self.instr().on_event(&crate::Event::ExecuteStart {
305            stmt: crate::StmtRef::Inline { sql },
306            param_count: params.len(),
307        });
308        let started = std::time::Instant::now();
309        let res = self.query_typed_internal_inner(sql, params).await;
310        let duration = started.elapsed();
311        let (rows, outcome) = match &res {
312            Ok(pipeline::QueryResult::Rows(v)) => (v.len() as u64, crate::Outcome::Ok),
313            Ok(pipeline::QueryResult::Command(r)) => (r.rows_affected, crate::Outcome::Ok),
314            Err(e) => (0, crate::Outcome::Err(e)),
315        };
316        self.instr().on_event(&crate::Event::ExecuteFinish {
317            stmt: crate::StmtRef::Inline { sql },
318            rows,
319            duration,
320            outcome,
321        });
322        res
323    }
324
325    async fn query_typed_internal_inner(
326        &mut self,
327        sql: &str,
328        params: &[(&(dyn ToSql + Sync), Oid)],
329    ) -> Result<pipeline::QueryResult> {
330        let param_types: Vec<u32> = params.iter().map(|(_, oid)| oid.0).collect();
331        let mut encoded_params: Vec<Option<Vec<u8>>> = Vec::with_capacity(params.len());
332
333        for (value, _) in params {
334            if value.is_null() {
335                encoded_params.push(None);
336            } else {
337                let mut buf = BytesMut::new();
338                value.to_sql(&mut buf)?;
339                encoded_params.push(Some(buf.to_vec()));
340            }
341        }
342
343        let mut batch = PipelineBatch::new();
344        batch.add(sql.to_string(), param_types, encoded_params);
345
346        let mut results = batch.execute(&mut self.conn).await?;
347        results
348            .pop()
349            .ok_or_else(|| Error::protocol("pipeline returned no results"))
350    }
351}