xitca_postgres/
prepare.rs

1use core::future::Future;
2
3use std::sync::Arc;
4
5use postgres_types::{Field, Kind, Oid};
6
7use super::{
8    BoxedFuture,
9    client::Client,
10    error::{DbError, Error, SqlState},
11    execute::{Execute, ExecuteBlocking},
12    iter::AsyncLendingIterator,
13    query::Query,
14    statement::{Statement, StatementNamed},
15    types::Type,
16};
17
/// trait generic over preparing statement and canceling of prepared statement
///
/// The methods declared here resolve a postgres type [`Oid`] into a [`Type`], in an async, a
/// boxed-async and a blocking flavor.
pub trait Prepare: Query + Sync {
    /// Resolve the [`Type`] identified by `oid`.
    ///
    /// The returned future is required to be `Send`.
    fn _get_type(&self, oid: Oid) -> impl Future<Output = Result<Type, Error>> + Send;

    // Heap-allocated variant of `_get_type`: the future is pinned in a `Box` so its type is
    // nameable and has a fixed size. The `Client` implementation in this file calls it whenever
    // a type lookup has to recurse into another type lookup.
    #[doc(hidden)]
    fn _get_type_boxed(&self, oid: Oid) -> BoxedFuture<'_, Result<Type, Error>> {
        Box::pin(self._get_type(oid))
    }

    /// blocking version of [`Prepare::_get_type`].
    fn _get_type_blocking(&self, oid: Oid) -> Result<Type, Error>;
}
30
31impl<T> Prepare for &T
32where
33    T: Prepare,
34{
35    #[inline]
36    async fn _get_type(&self, oid: Oid) -> Result<Type, Error> {
37        T::_get_type(*self, oid).await
38    }
39
40    #[inline]
41    fn _get_type_blocking(&self, oid: Oid) -> Result<Type, Error> {
42        T::_get_type_blocking(*self, oid)
43    }
44}
45
46impl<T> Prepare for &mut T
47where
48    T: Prepare,
49{
50    #[inline]
51    async fn _get_type(&self, oid: Oid) -> Result<Type, Error> {
52        T::_get_type(&**self, oid).await
53    }
54
55    #[inline]
56    fn _get_type_blocking(&self, oid: Oid) -> Result<Type, Error> {
57        T::_get_type_blocking(&**self, oid)
58    }
59}
60
61impl Prepare for Client {
62    async fn _get_type(&self, oid: Oid) -> Result<Type, Error> {
63        if let Some(ty) = Type::from_oid(oid).or_else(|| self.type_(oid)) {
64            return Ok(ty);
65        }
66
67        let stmt = self.typeinfo_statement().await?;
68
69        let mut rows = stmt.bind([oid]).query(self).await?;
70        let row = rows.try_next().await?.ok_or_else(Error::unexpected)?;
71
72        let name = row.try_get::<String>(0)?;
73        let type_ = row.try_get::<i8>(1)?;
74        let elem_oid = row.try_get::<Oid>(2)?;
75        let rngsubtype = row.try_get::<Option<Oid>>(3)?;
76        let basetype = row.try_get::<Oid>(4)?;
77        let schema = row.try_get::<String>(5)?;
78        let relid = row.try_get::<Oid>(6)?;
79
80        let kind = if type_ == b'e' as i8 {
81            let variants = self.get_enum_variants(oid).await?;
82            Kind::Enum(variants)
83        } else if type_ == b'p' as i8 {
84            Kind::Pseudo
85        } else if basetype != 0 {
86            let type_ = self._get_type_boxed(basetype).await?;
87            Kind::Domain(type_)
88        } else if elem_oid != 0 {
89            let type_ = self._get_type_boxed(elem_oid).await?;
90            Kind::Array(type_)
91        } else if relid != 0 {
92            let fields = self.get_composite_fields(relid).await?;
93            Kind::Composite(fields)
94        } else if let Some(rngsubtype) = rngsubtype {
95            let type_ = self._get_type_boxed(rngsubtype).await?;
96            Kind::Range(type_)
97        } else {
98            Kind::Simple
99        };
100
101        let type_ = Type::new(name, oid, kind, schema);
102        self.set_type(oid, &type_);
103
104        Ok(type_)
105    }
106
107    fn _get_type_blocking(&self, oid: Oid) -> Result<Type, Error> {
108        if let Some(ty) = Type::from_oid(oid).or_else(|| self.type_(oid)) {
109            return Ok(ty);
110        }
111
112        let stmt = self.typeinfo_statement_blocking()?;
113
114        let rows = stmt.bind([oid]).query_blocking(self)?;
115        let row = rows.into_iter().next().ok_or_else(Error::unexpected)??;
116
117        let name = row.try_get::<String>(0)?;
118        let type_ = row.try_get::<i8>(1)?;
119        let elem_oid = row.try_get::<Oid>(2)?;
120        let rngsubtype = row.try_get::<Option<Oid>>(3)?;
121        let basetype = row.try_get::<Oid>(4)?;
122        let schema = row.try_get::<String>(5)?;
123        let relid = row.try_get::<Oid>(6)?;
124
125        let kind = if type_ == b'e' as i8 {
126            let variants = self.get_enum_variants_blocking(oid)?;
127            Kind::Enum(variants)
128        } else if type_ == b'p' as i8 {
129            Kind::Pseudo
130        } else if basetype != 0 {
131            let type_ = self._get_type_blocking(basetype)?;
132            Kind::Domain(type_)
133        } else if elem_oid != 0 {
134            let type_ = self._get_type_blocking(elem_oid)?;
135            Kind::Array(type_)
136        } else if relid != 0 {
137            let fields = self.get_composite_fields_blocking(relid)?;
138            Kind::Composite(fields)
139        } else if let Some(rngsubtype) = rngsubtype {
140            let type_ = self._get_type_blocking(rngsubtype)?;
141            Kind::Range(type_)
142        } else {
143            Kind::Simple
144        };
145
146        let type_ = Type::new(name, oid, kind, schema);
147        self.set_type(oid, &type_);
148
149        Ok(type_)
150    }
151}
152
153impl Prepare for Arc<Client> {
154    #[inline]
155    async fn _get_type(&self, oid: Oid) -> Result<Type, Error> {
156        Client::_get_type(&**self, oid).await
157    }
158
159    #[inline]
160    fn _get_type_blocking(&self, oid: Oid) -> Result<Type, Error> {
161        Client::_get_type_blocking(&**self, oid)
162    }
163}
164
/// Catalog lookup for a single type, selected by oid (`$1`).
///
/// Columns, in order: `typname`, `typtype`, `typelem`, `rngsubtype`, `typbasetype`, `nspname`,
/// `typrelid`. `pg_range` is left-joined, so `rngsubtype` is NULL for a type without a matching
/// `pg_range` row.
// NOTE(review): the second argument is an empty slice; presumably the parameter types are left
// for the server to infer. Confirm against `Statement::named`.
const TYPEINFO_QUERY: StatementNamed = Statement::named(
    "SELECT t.typname, t.typtype, t.typelem, r.rngsubtype, t.typbasetype, n.nspname, t.typrelid \
    FROM pg_catalog.pg_type t \
    LEFT OUTER JOIN pg_catalog.pg_range r ON r.rngtypid = t.oid \
    INNER JOIN pg_catalog.pg_namespace n ON t.typnamespace = n.oid \
    WHERE t.oid = $1",
    &[],
);
173
/// Variant of [`TYPEINFO_QUERY`] that does not touch `pg_range`.
///
/// It yields the same seven columns in the same order, with `NULL::OID` in the `rngsubtype`
/// position. The statement helpers in this file prepare it when preparing [`TYPEINFO_QUERY`]
/// fails with `SqlState::UNDEFINED_TABLE`.
// Range types weren't added until Postgres 9.2, so pg_range may not exist
const TYPEINFO_FALLBACK_QUERY: StatementNamed = Statement::named(
    "SELECT t.typname, t.typtype, t.typelem, NULL::OID, t.typbasetype, n.nspname, t.typrelid \
    FROM pg_catalog.pg_type t \
    INNER JOIN pg_catalog.pg_namespace n ON t.typnamespace = n.oid \
    WHERE t.oid = $1",
    &[],
);
182
/// Lists the labels (`enumlabel`) of the enum type whose oid is `$1`, ordered by
/// `enumsortorder`.
const TYPEINFO_ENUM_QUERY: StatementNamed = Statement::named(
    "SELECT enumlabel \
    FROM pg_catalog.pg_enum \
    WHERE enumtypid = $1 \
    ORDER BY enumsortorder",
    &[],
);
190
/// Variant of [`TYPEINFO_ENUM_QUERY`] that orders the labels by `oid` instead of
/// `enumsortorder`.
///
/// The statement helpers in this file prepare it when preparing [`TYPEINFO_ENUM_QUERY`] fails
/// with `SqlState::UNDEFINED_COLUMN`.
// Postgres 9.0 didn't have enumsortorder
const TYPEINFO_ENUM_FALLBACK_QUERY: StatementNamed = Statement::named(
    "SELECT enumlabel \
    FROM pg_catalog.pg_enum \
    WHERE enumtypid = $1 \
    ORDER BY oid",
    &[],
);
199
/// Lists the attributes of the relation whose oid is `$1` as `(attname, atttypid)` pairs,
/// ordered by `attnum`.
///
/// Dropped attributes (`attisdropped`) and attributes with `attnum <= 0` are filtered out.
const TYPEINFO_COMPOSITE_QUERY: StatementNamed = Statement::named(
    "SELECT attname, atttypid \
    FROM pg_catalog.pg_attribute \
    WHERE attrelid = $1 \
    AND NOT attisdropped \
    AND attnum > 0 \
    ORDER BY attnum",
    &[],
);
209
210impl Client {
211    async fn get_enum_variants(&self, oid: Oid) -> Result<Vec<String>, Error> {
212        let stmt = self.typeinfo_enum_statement().await?;
213        let mut rows = stmt.bind([oid]).query(self).await?;
214        let mut res = Vec::new();
215        while let Some(row) = rows.try_next().await? {
216            let variant = row.try_get(0)?;
217            res.push(variant);
218        }
219        Ok(res)
220    }
221
222    async fn get_composite_fields(&self, oid: Oid) -> Result<Vec<Field>, Error> {
223        let stmt = self.typeinfo_composite_statement().await?;
224        let mut rows = stmt.bind([oid]).query(self).await?;
225        let mut fields = Vec::new();
226        while let Some(row) = rows.try_next().await? {
227            let name = row.try_get(0)?;
228            let oid = row.try_get(1)?;
229            let type_ = self._get_type_boxed(oid).await?;
230            fields.push(Field::new(name, type_));
231        }
232        Ok(fields)
233    }
234
235    async fn typeinfo_statement(&self) -> Result<Statement, Error> {
236        if let Some(stmt) = self.typeinfo() {
237            return Ok(stmt);
238        }
239        let stmt = match TYPEINFO_QUERY.execute(self).await.map(|stmt| stmt.leak()) {
240            Ok(stmt) => stmt,
241            Err(e) => {
242                return if e
243                    .downcast_ref::<DbError>()
244                    .is_some_and(|e| SqlState::UNDEFINED_TABLE.eq(e.code()))
245                {
246                    TYPEINFO_FALLBACK_QUERY.execute(self).await.map(|stmt| stmt.leak())
247                } else {
248                    Err(e)
249                };
250            }
251        };
252        self.set_typeinfo(&stmt);
253        Ok(stmt)
254    }
255
256    async fn typeinfo_enum_statement(&self) -> Result<Statement, Error> {
257        if let Some(stmt) = self.typeinfo_enum() {
258            return Ok(stmt);
259        }
260        let stmt = match TYPEINFO_ENUM_QUERY.execute(self).await {
261            Ok(stmt) => stmt.leak(),
262            Err(e) => {
263                return if e
264                    .downcast_ref::<DbError>()
265                    .is_some_and(|e| SqlState::UNDEFINED_COLUMN.eq(e.code()))
266                {
267                    TYPEINFO_ENUM_FALLBACK_QUERY.execute(self).await.map(|stmt| stmt.leak())
268                } else {
269                    Err(e)
270                };
271            }
272        };
273        self.set_typeinfo_enum(&stmt);
274        Ok(stmt)
275    }
276
277    async fn typeinfo_composite_statement(&self) -> Result<Statement, Error> {
278        if let Some(stmt) = self.typeinfo_composite() {
279            return Ok(stmt);
280        }
281        let stmt = TYPEINFO_COMPOSITE_QUERY.execute(self).await?.leak();
282        self.set_typeinfo_composite(&stmt);
283        Ok(stmt)
284    }
285}
286
287impl Client {
288    fn get_enum_variants_blocking(&self, oid: Oid) -> Result<Vec<String>, Error> {
289        let stmt = self.typeinfo_enum_statement_blocking()?;
290        stmt.bind([oid])
291            .query_blocking(self)?
292            .into_iter()
293            .map(|row| row?.try_get(0))
294            .collect()
295    }
296
297    fn get_composite_fields_blocking(&self, oid: Oid) -> Result<Vec<Field>, Error> {
298        let stmt = self.typeinfo_composite_statement_blocking()?;
299        stmt.bind([oid])
300            .query_blocking(self)?
301            .into_iter()
302            .map(|row| {
303                let row = row?;
304                let name = row.try_get(0)?;
305                let oid = row.try_get(1)?;
306                let type_ = self._get_type_blocking(oid)?;
307                Ok(Field::new(name, type_))
308            })
309            .collect()
310    }
311
312    fn typeinfo_statement_blocking(&self) -> Result<Statement, Error> {
313        if let Some(stmt) = self.typeinfo() {
314            return Ok(stmt);
315        }
316        let stmt = match TYPEINFO_QUERY.execute_blocking(self) {
317            Ok(stmt) => stmt.leak(),
318            Err(e) => {
319                return if e
320                    .downcast_ref::<DbError>()
321                    .is_some_and(|e| SqlState::UNDEFINED_TABLE.eq(e.code()))
322                {
323                    TYPEINFO_FALLBACK_QUERY.execute_blocking(self).map(|stmt| stmt.leak())
324                } else {
325                    Err(e)
326                };
327            }
328        };
329        self.set_typeinfo(&stmt);
330        Ok(stmt)
331    }
332
333    fn typeinfo_enum_statement_blocking(&self) -> Result<Statement, Error> {
334        if let Some(stmt) = self.typeinfo_enum() {
335            return Ok(stmt);
336        }
337        let stmt = match TYPEINFO_ENUM_QUERY.execute_blocking(self) {
338            Ok(stmt) => stmt.leak(),
339            Err(e) => {
340                return if e
341                    .downcast_ref::<DbError>()
342                    .is_some_and(|e| SqlState::UNDEFINED_COLUMN.eq(e.code()))
343                {
344                    TYPEINFO_ENUM_FALLBACK_QUERY
345                        .execute_blocking(self)
346                        .map(|stmt| stmt.leak())
347                } else {
348                    Err(e)
349                };
350            }
351        };
352        self.set_typeinfo_enum(&stmt);
353        Ok(stmt)
354    }
355
356    fn typeinfo_composite_statement_blocking(&self) -> Result<Statement, Error> {
357        if let Some(stmt) = self.typeinfo_composite() {
358            return Ok(stmt);
359        }
360        let stmt = TYPEINFO_COMPOSITE_QUERY.execute_blocking(self)?.leak();
361        self.set_typeinfo_composite(&stmt);
362        Ok(stmt)
363    }
364}