1use core::{ops::Deref, sync::atomic::Ordering};
4
5use super::{
6 column::Column,
7 driver::codec::{encode::StatementCancel, AsParams},
8 query::Query,
9 types::{ToSql, Type},
10};
11
12pub struct StatementGuarded<'a, C>
18where
19 C: Query,
20{
21 stmt: Option<Statement>,
22 cli: &'a C,
23}
24
25impl<C> AsRef<Statement> for StatementGuarded<'_, C>
26where
27 C: Query,
28{
29 #[inline]
30 fn as_ref(&self) -> &Statement {
31 self
32 }
33}
34
35impl<C> Deref for StatementGuarded<'_, C>
36where
37 C: Query,
38{
39 type Target = Statement;
40
41 fn deref(&self) -> &Self::Target {
42 self.stmt.as_ref().unwrap()
43 }
44}
45
46impl<C> Drop for StatementGuarded<'_, C>
47where
48 C: Query,
49{
50 fn drop(&mut self) {
51 if let Some(stmt) = self.stmt.take() {
52 let _ = self.cli._send_encode_query(StatementCancel { name: stmt.name() });
53 }
54 }
55}
56
57impl<C> StatementGuarded<'_, C>
58where
59 C: Query,
60{
61 pub fn leak(mut self) -> Statement {
64 self.stmt.take().unwrap()
65 }
66}
67
68#[derive(Default)]
81pub struct Statement {
82 name: Box<str>,
83 params: Box<[Type]>,
84 columns: Box<[Column]>,
85}
86
87impl Statement {
88 pub(crate) fn new(name: String, params: Vec<Type>, columns: Vec<Column>) -> Self {
89 Self {
90 name: name.into_boxed_str(),
91 params: params.into_boxed_slice(),
92 columns: columns.into_boxed_slice(),
93 }
94 }
95
96 pub(crate) fn duplicate(&self) -> Self {
97 Self {
98 name: self.name.clone(),
99 params: self.params.clone(),
100 columns: self.columns.clone(),
101 }
102 }
103
104 pub(crate) fn name(&self) -> &str {
105 &self.name
106 }
107
108 #[inline]
113 pub const fn named<'a>(stmt: &'a str, types: &'a [Type]) -> StatementNamed<'a> {
114 StatementNamed { stmt, types }
115 }
116
117 #[inline]
120 pub const fn unnamed<'a>(stmt: &'a str, types: &'a [Type]) -> StatementUnnamed<'a> {
121 StatementUnnamed { stmt, types }
122 }
123
124 #[inline]
139 pub fn bind<P>(&self, params: P) -> StatementQuery<'_, P>
140 where
141 P: AsParams,
142 {
143 StatementQuery { stmt: self, params }
144 }
145
146 #[inline]
156 pub fn bind_dyn<'p, 't>(
157 &self,
158 params: &'p [&'t (dyn ToSql + Sync)],
159 ) -> StatementQuery<'_, impl ExactSizeIterator<Item = &'t (dyn ToSql + Sync)> + 'p> {
160 self.bind(params.iter().cloned())
161 }
162
163 #[inline]
165 pub fn params(&self) -> &[Type] {
166 &self.params
167 }
168
169 #[inline]
171 pub fn columns(&self) -> &[Column] {
172 &self.columns
173 }
174
175 #[inline]
177 pub fn into_guarded<C>(self, cli: &C) -> StatementGuarded<C>
178 where
179 C: Query,
180 {
181 StatementGuarded { stmt: Some(self), cli }
182 }
183}
184
185#[derive(Clone, Copy)]
186pub struct StatementNamed<'a> {
187 pub(crate) stmt: &'a str,
188 pub(crate) types: &'a [Type],
189}
190
191impl StatementNamed<'_> {
192 fn name() -> String {
193 let id = crate::NEXT_ID.fetch_add(1, Ordering::Relaxed);
194 format!("s{id}")
195 }
196}
197
198pub(crate) struct StatementCreate<'a, 'c, C> {
199 pub(crate) name: String,
200 pub(crate) stmt: &'a str,
201 pub(crate) types: &'a [Type],
202 pub(crate) cli: &'c C,
203}
204
205impl<'a, 'c, C> From<(StatementNamed<'a>, &'c C)> for StatementCreate<'a, 'c, C> {
206 fn from((stmt, cli): (StatementNamed<'a>, &'c C)) -> Self {
207 Self {
208 name: StatementNamed::name(),
209 stmt: stmt.stmt,
210 types: stmt.types,
211 cli,
212 }
213 }
214}
215
216pub(crate) struct StatementCreateBlocking<'a, 'c, C> {
217 pub(crate) name: String,
218 pub(crate) stmt: &'a str,
219 pub(crate) types: &'a [Type],
220 pub(crate) cli: &'c C,
221}
222
223impl<'a, 'c, C> From<(StatementNamed<'a>, &'c C)> for StatementCreateBlocking<'a, 'c, C> {
224 fn from((stmt, cli): (StatementNamed<'a>, &'c C)) -> Self {
225 Self {
226 name: StatementNamed::name(),
227 stmt: stmt.stmt,
228 types: stmt.types,
229 cli,
230 }
231 }
232}
233
234pub struct StatementUnnamed<'a> {
238 pub(crate) stmt: &'a str,
239 pub(crate) types: &'a [Type],
240}
241
242impl<'a> StatementUnnamed<'a> {
243 #[inline]
245 pub fn bind<P>(self, params: P) -> StatementUnnamedBind<'a, P> {
246 StatementUnnamedBind {
247 stmt: self.stmt,
248 types: self.types,
249 params,
250 }
251 }
252
253 #[inline]
255 pub fn bind_dyn<'p, 't>(
256 self,
257 params: &'p [&'t (dyn ToSql + Sync)],
258 ) -> StatementUnnamedBind<'a, impl ExactSizeIterator<Item = &'t (dyn ToSql + Sync)> + 'p> {
259 self.bind(params.iter().cloned())
260 }
261}
262
263pub struct StatementQuery<'a, P> {
265 pub(crate) stmt: &'a Statement,
266 pub(crate) params: P,
267}
268
269pub struct StatementUnnamedBind<'a, P> {
271 stmt: &'a str,
272 types: &'a [Type],
273 params: P,
274}
275
276pub(crate) struct StatementUnnamedQuery<'a, 'c, P, C> {
277 pub(crate) stmt: &'a str,
278 pub(crate) types: &'a [Type],
279 pub(crate) params: P,
280 pub(crate) cli: &'c C,
281}
282
283impl<'a, 'c, P, C> From<(StatementUnnamedBind<'a, P>, &'c C)> for StatementUnnamedQuery<'a, 'c, P, C> {
284 fn from((bind, cli): (StatementUnnamedBind<'a, P>, &'c C)) -> Self {
285 Self {
286 stmt: bind.stmt,
287 types: bind.types,
288 params: bind.params,
289 cli,
290 }
291 }
292}
293
#[cfg(feature = "compat")]
pub(crate) mod compat {
    use core::ops::Deref;

    use std::sync::Arc;

    use super::{Query, Statement, StatementCancel};

    /// Shareable, owning counterpart of [`super::StatementGuarded`].
    ///
    /// Clones share one [`Statement`] and one client through an [`Arc`]. When the last clone
    /// is dropped, a `StatementCancel` for the statement's name is sent through the client;
    /// any error from that send is ignored.
    pub struct StatementGuarded<C>
    where
        C: Query,
    {
        inner: Arc<_StatementGuarded<C>>,
    }

    // Written by hand: `#[derive(Clone)]` would add a needless `C: Clone` bound, while cloning
    // only bumps the `Arc` reference count and never clones the client.
    impl<C> Clone for StatementGuarded<C>
    where
        C: Query,
    {
        fn clone(&self) -> Self {
            Self {
                inner: Arc::clone(&self.inner),
            }
        }
    }

    // Shared state behind the `Arc`; its `Drop` runs once, when the last handle goes away.
    struct _StatementGuarded<C>
    where
        C: Query,
    {
        stmt: Statement,
        cli: C,
    }

    impl<C> Drop for _StatementGuarded<C>
    where
        C: Query,
    {
        fn drop(&mut self) {
            // Best-effort: `Drop` cannot propagate a send failure, so the result is discarded.
            let _ = self.cli._send_encode_query(StatementCancel { name: self.stmt.name() });
        }
    }

    impl<C> Deref for StatementGuarded<C>
    where
        C: Query,
    {
        type Target = Statement;

        fn deref(&self) -> &Self::Target {
            &self.inner.stmt
        }
    }

    impl<C> AsRef<Statement> for StatementGuarded<C>
    where
        C: Query,
    {
        fn as_ref(&self) -> &Statement {
            &self.inner.stmt
        }
    }

    impl<C> StatementGuarded<C>
    where
        C: Query,
    {
        /// Takes ownership of `stmt` and `cli`, returning the first shared handle to them.
        pub fn new(stmt: Statement, cli: C) -> Self {
            Self {
                inner: Arc::new(_StatementGuarded { stmt, cli }),
            }
        }
    }
}
364
#[cfg(test)]
mod test {
    use core::future::IntoFuture;

    use crate::{
        error::{DbError, SqlState},
        execute::Execute,
        iter::AsyncLendingIterator,
        statement::Statement,
        Postgres,
    };

    /// After the value returned by executing a named statement is dropped, querying through a
    /// duplicate of it must fail with `INVALID_SQL_STATEMENT_NAME`.
    ///
    /// Requires a Postgres server at `localhost:5432` and the `./samples/test.sql` fixture.
    #[tokio::test]
    async fn cancel_statement() {
        let (cli, drv) = Postgres::new("postgres://postgres:postgres@localhost:5432")
            .connect()
            .await
            .unwrap();

        // Drive the connection in the background for the rest of the test.
        tokio::task::spawn(drv.into_future());

        std::path::Path::new("./samples/test.sql").execute(&cli).await.unwrap();

        let guarded = Statement::named("SELECT id, name FROM foo ORDER BY id", &[])
            .execute(&cli)
            .await
            .unwrap();

        // Keep a copy that shares the statement name, then drop the original.
        let raw = guarded.duplicate();
        drop(guarded);

        let mut rows = raw.query(&cli).await.unwrap();
        let err = rows.try_next().await.err().unwrap();
        let db_err = err.downcast_ref::<DbError>().unwrap();

        assert_eq!(db_err.code(), &SqlState::INVALID_SQL_STATEMENT_NAME);
    }
}
405}