xitca_postgres/
prepare.rs

1use postgres_types::{Field, Kind, Oid};
2
3use super::{
4    client::Client,
5    error::{DbError, Error, SqlState},
6    execute::{Execute, ExecuteBlocking},
7    iter::AsyncLendingIterator,
8    query::Query,
9    statement::{Statement, StatementNamed},
10    types::Type,
11    BoxedFuture,
12};
13
/// Trait generic over preparing statements and canceling of prepared statements.
///
/// The methods declared here resolve a Postgres type [`Oid`] into a [`Type`]; they are
/// underscore-prefixed, which suggests they are meant for internal use rather than direct calls
/// by end users.
pub trait Prepare: Query + Sync {
    // Resolve `oid` into a `Type`, yielding an `Error` when the lookup fails.
    // get type is called recursively so a boxed future is needed.
    fn _get_type(&self, oid: Oid) -> BoxedFuture<'_, Result<Type, Error>>;

    // blocking version of [`Prepare::_get_type`]: same inputs and outputs, but the result is
    // returned directly instead of through a future.
    fn _get_type_blocking(&self, oid: Oid) -> Result<Type, Error>;
}
22
23impl Prepare for Client {
24    fn _get_type(&self, oid: Oid) -> BoxedFuture<'_, Result<Type, Error>> {
25        Box::pin(async move {
26            if let Some(ty) = Type::from_oid(oid).or_else(|| self.type_(oid)) {
27                return Ok(ty);
28            }
29
30            let stmt = self.typeinfo_statement().await?;
31
32            let mut rows = stmt.bind([oid]).query(self).await?;
33            let row = rows.try_next().await?.ok_or_else(Error::unexpected)?;
34
35            let name = row.try_get::<String>(0)?;
36            let type_ = row.try_get::<i8>(1)?;
37            let elem_oid = row.try_get::<Oid>(2)?;
38            let rngsubtype = row.try_get::<Option<Oid>>(3)?;
39            let basetype = row.try_get::<Oid>(4)?;
40            let schema = row.try_get::<String>(5)?;
41            let relid = row.try_get::<Oid>(6)?;
42
43            let kind = if type_ == b'e' as i8 {
44                let variants = self.get_enum_variants(oid).await?;
45                Kind::Enum(variants)
46            } else if type_ == b'p' as i8 {
47                Kind::Pseudo
48            } else if basetype != 0 {
49                let type_ = self._get_type(basetype).await?;
50                Kind::Domain(type_)
51            } else if elem_oid != 0 {
52                let type_ = self._get_type(elem_oid).await?;
53                Kind::Array(type_)
54            } else if relid != 0 {
55                let fields = self.get_composite_fields(relid).await?;
56                Kind::Composite(fields)
57            } else if let Some(rngsubtype) = rngsubtype {
58                let type_ = self._get_type(rngsubtype).await?;
59                Kind::Range(type_)
60            } else {
61                Kind::Simple
62            };
63
64            let type_ = Type::new(name, oid, kind, schema);
65            self.set_type(oid, &type_);
66
67            Ok(type_)
68        })
69    }
70
71    fn _get_type_blocking(&self, oid: Oid) -> Result<Type, Error> {
72        if let Some(ty) = Type::from_oid(oid).or_else(|| self.type_(oid)) {
73            return Ok(ty);
74        }
75
76        let stmt = self.typeinfo_statement_blocking()?;
77
78        let rows = stmt.bind([oid]).query_blocking(self)?;
79        let row = rows.into_iter().next().ok_or_else(Error::unexpected)??;
80
81        let name = row.try_get::<String>(0)?;
82        let type_ = row.try_get::<i8>(1)?;
83        let elem_oid = row.try_get::<Oid>(2)?;
84        let rngsubtype = row.try_get::<Option<Oid>>(3)?;
85        let basetype = row.try_get::<Oid>(4)?;
86        let schema = row.try_get::<String>(5)?;
87        let relid = row.try_get::<Oid>(6)?;
88
89        let kind = if type_ == b'e' as i8 {
90            let variants = self.get_enum_variants_blocking(oid)?;
91            Kind::Enum(variants)
92        } else if type_ == b'p' as i8 {
93            Kind::Pseudo
94        } else if basetype != 0 {
95            let type_ = self._get_type_blocking(basetype)?;
96            Kind::Domain(type_)
97        } else if elem_oid != 0 {
98            let type_ = self._get_type_blocking(elem_oid)?;
99            Kind::Array(type_)
100        } else if relid != 0 {
101            let fields = self.get_composite_fields_blocking(relid)?;
102            Kind::Composite(fields)
103        } else if let Some(rngsubtype) = rngsubtype {
104            let type_ = self._get_type_blocking(rngsubtype)?;
105            Kind::Range(type_)
106        } else {
107            Kind::Simple
108        };
109
110        let type_ = Type::new(name, oid, kind, schema);
111        self.set_type(oid, &type_);
112
113        Ok(type_)
114    }
115}
116
// Catalog lookup for a single type, keyed by its oid (`$1`).
//
// The selected columns are read positionally (indices 0 through 6) by `_get_type` and
// `_get_type_blocking`, so the SELECT list order must stay in sync with those readers.
// NOTE(review): the trailing `&[]` is presumably the list of explicit parameter types, left
// empty here; confirm against `Statement::named`.
const TYPEINFO_QUERY: StatementNamed = Statement::named(
    "SELECT t.typname, t.typtype, t.typelem, r.rngsubtype, t.typbasetype, n.nspname, t.typrelid \
    FROM pg_catalog.pg_type t \
    LEFT OUTER JOIN pg_catalog.pg_range r ON r.rngtypid = t.oid \
    INNER JOIN pg_catalog.pg_namespace n ON t.typnamespace = n.oid \
    WHERE t.oid = $1",
    &[],
);
125
// Range types weren't added until Postgres 9.2, so pg_range may not exist.
//
// Same lookup as `TYPEINFO_QUERY` without the `pg_range` join: `NULL::OID` takes the place of
// `rngsubtype` so the column positions stay identical. It is prepared when preparing
// `TYPEINFO_QUERY` fails with `SqlState::UNDEFINED_TABLE`.
const TYPEINFO_FALLBACK_QUERY: StatementNamed = Statement::named(
    "SELECT t.typname, t.typtype, t.typelem, NULL::OID, t.typbasetype, n.nspname, t.typrelid \
    FROM pg_catalog.pg_type t \
    INNER JOIN pg_catalog.pg_namespace n ON t.typnamespace = n.oid \
    WHERE t.oid = $1",
    &[],
);
134
// Selects the labels of the enum type whose oid is `$1`, ordered by `enumsortorder`.
// The single result column is read at index 0 by the enum variant helpers.
const TYPEINFO_ENUM_QUERY: StatementNamed = Statement::named(
    "SELECT enumlabel \
    FROM pg_catalog.pg_enum \
    WHERE enumtypid = $1 \
    ORDER BY enumsortorder",
    &[],
);
142
// Postgres 9.0 didn't have enumsortorder.
//
// Same selection as `TYPEINFO_ENUM_QUERY` but ordered by `oid` instead. It is prepared when
// preparing `TYPEINFO_ENUM_QUERY` fails with `SqlState::UNDEFINED_COLUMN`.
const TYPEINFO_ENUM_FALLBACK_QUERY: StatementNamed = Statement::named(
    "SELECT enumlabel \
    FROM pg_catalog.pg_enum \
    WHERE enumtypid = $1 \
    ORDER BY oid",
    &[],
);
151
// Selects the attribute name and attribute type oid for every attribute of the relation whose
// oid is `$1`, skipping dropped attributes and those with a non-positive `attnum`, ordered by
// `attnum`. The composite field helpers read the name at index 0 and the type oid at index 1.
const TYPEINFO_COMPOSITE_QUERY: StatementNamed = Statement::named(
    "SELECT attname, atttypid \
    FROM pg_catalog.pg_attribute \
    WHERE attrelid = $1 \
    AND NOT attisdropped \
    AND attnum > 0 \
    ORDER BY attnum",
    &[],
);
161
162impl Client {
163    async fn get_enum_variants(&self, oid: Oid) -> Result<Vec<String>, Error> {
164        let stmt = self.typeinfo_enum_statement().await?;
165        let mut rows = stmt.bind([oid]).query(self).await?;
166        let mut res = Vec::new();
167        while let Some(row) = rows.try_next().await? {
168            let variant = row.try_get(0)?;
169            res.push(variant);
170        }
171        Ok(res)
172    }
173
174    async fn get_composite_fields(&self, oid: Oid) -> Result<Vec<Field>, Error> {
175        let stmt = self.typeinfo_composite_statement().await?;
176        let mut rows = stmt.bind([oid]).query(self).await?;
177        let mut fields = Vec::new();
178        while let Some(row) = rows.try_next().await? {
179            let name = row.try_get(0)?;
180            let oid = row.try_get(1)?;
181            let type_ = self._get_type(oid).await?;
182            fields.push(Field::new(name, type_));
183        }
184        Ok(fields)
185    }
186
187    async fn typeinfo_statement(&self) -> Result<Statement, Error> {
188        if let Some(stmt) = self.typeinfo() {
189            return Ok(stmt);
190        }
191        let stmt = match TYPEINFO_QUERY.execute(self).await.map(|stmt| stmt.leak()) {
192            Ok(stmt) => stmt,
193            Err(e) => {
194                return if e
195                    .downcast_ref::<DbError>()
196                    .is_some_and(|e| SqlState::UNDEFINED_TABLE.eq(e.code()))
197                {
198                    TYPEINFO_FALLBACK_QUERY.execute(self).await.map(|stmt| stmt.leak())
199                } else {
200                    Err(e)
201                }
202            }
203        };
204        self.set_typeinfo(&stmt);
205        Ok(stmt)
206    }
207
208    async fn typeinfo_enum_statement(&self) -> Result<Statement, Error> {
209        if let Some(stmt) = self.typeinfo_enum() {
210            return Ok(stmt);
211        }
212        let stmt = match TYPEINFO_ENUM_QUERY.execute(self).await {
213            Ok(stmt) => stmt.leak(),
214            Err(e) => {
215                return if e
216                    .downcast_ref::<DbError>()
217                    .is_some_and(|e| SqlState::UNDEFINED_COLUMN.eq(e.code()))
218                {
219                    TYPEINFO_ENUM_FALLBACK_QUERY.execute(self).await.map(|stmt| stmt.leak())
220                } else {
221                    Err(e)
222                }
223            }
224        };
225        self.set_typeinfo_enum(&stmt);
226        Ok(stmt)
227    }
228
229    async fn typeinfo_composite_statement(&self) -> Result<Statement, Error> {
230        if let Some(stmt) = self.typeinfo_composite() {
231            return Ok(stmt);
232        }
233        let stmt = TYPEINFO_COMPOSITE_QUERY.execute(self).await?.leak();
234        self.set_typeinfo_composite(&stmt);
235        Ok(stmt)
236    }
237}
238
239impl Client {
240    fn get_enum_variants_blocking(&self, oid: Oid) -> Result<Vec<String>, Error> {
241        let stmt = self.typeinfo_enum_statement_blocking()?;
242        stmt.bind([oid])
243            .query_blocking(self)?
244            .into_iter()
245            .map(|row| row?.try_get(0))
246            .collect()
247    }
248
249    fn get_composite_fields_blocking(&self, oid: Oid) -> Result<Vec<Field>, Error> {
250        let stmt = self.typeinfo_composite_statement_blocking()?;
251        stmt.bind([oid])
252            .query_blocking(self)?
253            .into_iter()
254            .map(|row| {
255                let row = row?;
256                let name = row.try_get(0)?;
257                let oid = row.try_get(1)?;
258                let type_ = self._get_type_blocking(oid)?;
259                Ok(Field::new(name, type_))
260            })
261            .collect()
262    }
263
264    fn typeinfo_statement_blocking(&self) -> Result<Statement, Error> {
265        if let Some(stmt) = self.typeinfo() {
266            return Ok(stmt);
267        }
268        let stmt = match TYPEINFO_QUERY.execute_blocking(self) {
269            Ok(stmt) => stmt.leak(),
270            Err(e) => {
271                return if e
272                    .downcast_ref::<DbError>()
273                    .is_some_and(|e| SqlState::UNDEFINED_TABLE.eq(e.code()))
274                {
275                    TYPEINFO_FALLBACK_QUERY.execute_blocking(self).map(|stmt| stmt.leak())
276                } else {
277                    Err(e)
278                }
279            }
280        };
281        self.set_typeinfo(&stmt);
282        Ok(stmt)
283    }
284
285    fn typeinfo_enum_statement_blocking(&self) -> Result<Statement, Error> {
286        if let Some(stmt) = self.typeinfo_enum() {
287            return Ok(stmt);
288        }
289        let stmt = match TYPEINFO_ENUM_QUERY.execute_blocking(self) {
290            Ok(stmt) => stmt.leak(),
291            Err(e) => {
292                return if e
293                    .downcast_ref::<DbError>()
294                    .is_some_and(|e| SqlState::UNDEFINED_COLUMN.eq(e.code()))
295                {
296                    TYPEINFO_ENUM_FALLBACK_QUERY
297                        .execute_blocking(self)
298                        .map(|stmt| stmt.leak())
299                } else {
300                    Err(e)
301                }
302            }
303        };
304        self.set_typeinfo_enum(&stmt);
305        Ok(stmt)
306    }
307
308    fn typeinfo_composite_statement_blocking(&self) -> Result<Statement, Error> {
309        if let Some(stmt) = self.typeinfo_composite() {
310            return Ok(stmt);
311        }
312        let stmt = TYPEINFO_COMPOSITE_QUERY.execute_blocking(self)?.leak();
313        self.set_typeinfo_composite(&stmt);
314        Ok(stmt)
315    }
316}