sqlx_core/postgres/connection/
describe.rs

1use crate::error::Error;
2use crate::ext::ustr::UStr;
3use crate::postgres::message::{ParameterDescription, RowDescription};
4use crate::postgres::statement::PgStatementMetadata;
5use crate::postgres::type_info::{PgCustomType, PgType, PgTypeKind};
6use crate::postgres::types::Oid;
7use crate::postgres::{PgArguments, PgColumn, PgConnection, PgTypeInfo};
8use crate::query_as::query_as;
9use crate::query_scalar::{query_scalar, query_scalar_with};
10use crate::types::Json;
11use crate::HashMap;
12use futures_core::future::BoxFuture;
13use std::fmt::Write;
14use std::sync::Arc;
15
/// The kind of a type as stored in the `pg_type.typtype` catalog column.
///
/// See <https://www.postgresql.org/docs/13/catalog-pg-type.html>
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
enum TypType {
    /// Code `b`.
    Base,
    /// Code `c`.
    Composite,
    /// Code `d`.
    Domain,
    /// Code `e`.
    Enum,
    /// Code `p`.
    Pseudo,
    /// Code `r`.
    Range,
}

impl TryFrom<u8> for TypType {
    type Error = ();

    /// Translates the single-byte `typtype` code into its variant.
    ///
    /// Any byte that is not one of the recognised codes yields `Err(())`.
    fn try_from(code: u8) -> Result<Self, Self::Error> {
        match code {
            b'b' => Ok(Self::Base),
            b'c' => Ok(Self::Composite),
            b'd' => Ok(Self::Domain),
            b'e' => Ok(Self::Enum),
            b'p' => Ok(Self::Pseudo),
            b'r' => Ok(Self::Range),
            _ => Err(()),
        }
    }
}
45
/// The category of a type as stored in the `pg_type.typcategory` catalog column.
///
/// See <https://www.postgresql.org/docs/13/catalog-pg-type.html#CATALOG-TYPCATEGORY-TABLE>
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
enum TypCategory {
    /// Code `A`.
    Array,
    /// Code `B`.
    Boolean,
    /// Code `C`.
    Composite,
    /// Code `D`.
    DateTime,
    /// Code `E`.
    Enum,
    /// Code `G`.
    Geometric,
    /// Code `I`.
    Network,
    /// Code `N`.
    Numeric,
    /// Code `P`.
    Pseudo,
    /// Code `R`.
    Range,
    /// Code `S`.
    String,
    /// Code `T`.
    Timespan,
    /// Code `U`.
    User,
    /// Code `V`.
    BitString,
    /// Code `X`.
    Unknown,
}

impl TryFrom<u8> for TypCategory {
    type Error = ();

    /// Translates the single-byte `typcategory` code into its variant.
    ///
    /// Any byte that is not one of the recognised codes yields `Err(())`.
    fn try_from(code: u8) -> Result<Self, Self::Error> {
        match code {
            b'A' => Ok(Self::Array),
            b'B' => Ok(Self::Boolean),
            b'C' => Ok(Self::Composite),
            b'D' => Ok(Self::DateTime),
            b'E' => Ok(Self::Enum),
            b'G' => Ok(Self::Geometric),
            b'I' => Ok(Self::Network),
            b'N' => Ok(Self::Numeric),
            b'P' => Ok(Self::Pseudo),
            b'R' => Ok(Self::Range),
            b'S' => Ok(Self::String),
            b'T' => Ok(Self::Timespan),
            b'U' => Ok(Self::User),
            b'V' => Ok(Self::BitString),
            b'X' => Ok(Self::Unknown),
            _ => Err(()),
        }
    }
}
93
94impl PgConnection {
95    pub(super) async fn handle_row_description(
96        &mut self,
97        desc: Option<RowDescription>,
98        should_fetch: bool,
99    ) -> Result<(Vec<PgColumn>, HashMap<UStr, usize>), Error> {
100        let mut columns = Vec::new();
101        let mut column_names = HashMap::new();
102
103        let desc = if let Some(desc) = desc {
104            desc
105        } else {
106            // no rows
107            return Ok((columns, column_names));
108        };
109
110        columns.reserve(desc.fields.len());
111        column_names.reserve(desc.fields.len());
112
113        for (index, field) in desc.fields.into_iter().enumerate() {
114            let name = UStr::from(field.name);
115
116            let type_info = self
117                .maybe_fetch_type_info_by_oid(field.data_type_id, should_fetch)
118                .await?;
119
120            let column = PgColumn {
121                ordinal: index,
122                name: name.clone(),
123                type_info,
124                relation_id: field.relation_id,
125                relation_attribute_no: field.relation_attribute_no,
126            };
127
128            columns.push(column);
129            column_names.insert(name, index);
130        }
131
132        Ok((columns, column_names))
133    }
134
135    pub(super) async fn handle_parameter_description(
136        &mut self,
137        desc: ParameterDescription,
138    ) -> Result<Vec<PgTypeInfo>, Error> {
139        let mut params = Vec::with_capacity(desc.types.len());
140
141        for ty in desc.types {
142            params.push(self.maybe_fetch_type_info_by_oid(ty, true).await?);
143        }
144
145        Ok(params)
146    }
147
148    async fn maybe_fetch_type_info_by_oid(
149        &mut self,
150        oid: Oid,
151        should_fetch: bool,
152    ) -> Result<PgTypeInfo, Error> {
153        // first we check if this is a built-in type
154        // in the average application, the vast majority of checks should flow through this
155        if let Some(info) = PgTypeInfo::try_from_oid(oid) {
156            return Ok(info);
157        }
158
159        // next we check a local cache for user-defined type names <-> object id
160        if let Some(info) = self.cache_type_info.get(&oid) {
161            return Ok(info.clone());
162        }
163
164        // fallback to asking the database directly for a type name
165        if should_fetch {
166            let info = self.fetch_type_by_oid(oid).await?;
167
168            // cache the type name <-> oid relationship in a paired hashmap
169            // so we don't come down this road again
170            self.cache_type_info.insert(oid, info.clone());
171            self.cache_type_oid
172                .insert(info.0.name().to_string().into(), oid);
173
174            Ok(info)
175        } else {
176            // we are not in a place that *can* run a query
177            // this generally means we are in the middle of another query
178            // this _should_ only happen for complex types sent through the TEXT protocol
179            // we're open to ideas to correct this.. but it'd probably be more efficient to figure
180            // out a way to "prime" the type cache for connections rather than make this
181            // fallback work correctly for complex user-defined types for the TEXT protocol
182            Ok(PgTypeInfo(PgType::DeclareWithOid(oid)))
183        }
184    }
185
186    fn fetch_type_by_oid(&mut self, oid: Oid) -> BoxFuture<'_, Result<PgTypeInfo, Error>> {
187        Box::pin(async move {
188            let (name, typ_type, category, relation_id, element, base_type): (String, i8, i8, Oid, Oid, Oid) = query_as(
189                "SELECT typname, typtype, typcategory, typrelid, typelem, typbasetype FROM pg_catalog.pg_type WHERE oid = $1",
190            )
191            .bind(oid)
192            .fetch_one(&mut *self)
193            .await?;
194
195            let typ_type = TypType::try_from(typ_type as u8);
196            let category = TypCategory::try_from(category as u8);
197
198            match (typ_type, category) {
199                (Ok(TypType::Domain), _) => self.fetch_domain_by_oid(oid, base_type, name).await,
200
201                (Ok(TypType::Base), Ok(TypCategory::Array)) => {
202                    Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
203                        kind: PgTypeKind::Array(
204                            self.maybe_fetch_type_info_by_oid(element, true).await?,
205                        ),
206                        name: name.into(),
207                        oid,
208                    }))))
209                }
210
211                (Ok(TypType::Pseudo), Ok(TypCategory::Pseudo)) => {
212                    Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
213                        kind: PgTypeKind::Pseudo,
214                        name: name.into(),
215                        oid,
216                    }))))
217                }
218
219                (Ok(TypType::Range), Ok(TypCategory::Range)) => {
220                    self.fetch_range_by_oid(oid, name).await
221                }
222
223                (Ok(TypType::Enum), Ok(TypCategory::Enum)) => {
224                    self.fetch_enum_by_oid(oid, name).await
225                }
226
227                (Ok(TypType::Composite), Ok(TypCategory::Composite)) => {
228                    self.fetch_composite_by_oid(oid, relation_id, name).await
229                }
230
231                _ => Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
232                    kind: PgTypeKind::Simple,
233                    name: name.into(),
234                    oid,
235                })))),
236            }
237        })
238    }
239
240    async fn fetch_enum_by_oid(&mut self, oid: Oid, name: String) -> Result<PgTypeInfo, Error> {
241        let variants: Vec<String> = query_scalar(
242            r#"
243SELECT enumlabel
244FROM pg_catalog.pg_enum
245WHERE enumtypid = $1
246ORDER BY enumsortorder
247            "#,
248        )
249        .bind(oid)
250        .fetch_all(self)
251        .await?;
252
253        Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
254            oid,
255            name: name.into(),
256            kind: PgTypeKind::Enum(Arc::from(variants)),
257        }))))
258    }
259
260    fn fetch_composite_by_oid(
261        &mut self,
262        oid: Oid,
263        relation_id: Oid,
264        name: String,
265    ) -> BoxFuture<'_, Result<PgTypeInfo, Error>> {
266        Box::pin(async move {
267            let raw_fields: Vec<(String, Oid)> = query_as(
268                r#"
269SELECT attname, atttypid
270FROM pg_catalog.pg_attribute
271WHERE attrelid = $1
272AND NOT attisdropped
273AND attnum > 0
274ORDER BY attnum
275                "#,
276            )
277            .bind(relation_id)
278            .fetch_all(&mut *self)
279            .await?;
280
281            let mut fields = Vec::new();
282
283            for (field_name, field_oid) in raw_fields.into_iter() {
284                let field_type = self.maybe_fetch_type_info_by_oid(field_oid, true).await?;
285
286                fields.push((field_name, field_type));
287            }
288
289            Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
290                oid,
291                name: name.into(),
292                kind: PgTypeKind::Composite(Arc::from(fields)),
293            }))))
294        })
295    }
296
297    fn fetch_domain_by_oid(
298        &mut self,
299        oid: Oid,
300        base_type: Oid,
301        name: String,
302    ) -> BoxFuture<'_, Result<PgTypeInfo, Error>> {
303        Box::pin(async move {
304            let base_type = self.maybe_fetch_type_info_by_oid(base_type, true).await?;
305
306            Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
307                oid,
308                name: name.into(),
309                kind: PgTypeKind::Domain(base_type),
310            }))))
311        })
312    }
313
314    fn fetch_range_by_oid(
315        &mut self,
316        oid: Oid,
317        name: String,
318    ) -> BoxFuture<'_, Result<PgTypeInfo, Error>> {
319        Box::pin(async move {
320            let element_oid: Oid = query_scalar(
321                r#"
322SELECT rngsubtype
323FROM pg_catalog.pg_range
324WHERE rngtypid = $1
325                "#,
326            )
327            .bind(oid)
328            .fetch_one(&mut *self)
329            .await?;
330
331            let element = self.maybe_fetch_type_info_by_oid(element_oid, true).await?;
332
333            Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
334                kind: PgTypeKind::Range(element),
335                name: name.into(),
336                oid,
337            }))))
338        })
339    }
340
341    pub(crate) async fn fetch_type_id_by_name(&mut self, name: &str) -> Result<Oid, Error> {
342        if let Some(oid) = self.cache_type_oid.get(name) {
343            return Ok(*oid);
344        }
345
346        // language=SQL
347        let (oid,): (Oid,) = query_as(
348            "
349SELECT oid FROM pg_catalog.pg_type WHERE typname ILIKE $1
350                ",
351        )
352        .bind(name)
353        .fetch_optional(&mut *self)
354        .await?
355        .ok_or_else(|| Error::TypeNotFound {
356            type_name: String::from(name),
357        })?;
358
359        self.cache_type_oid.insert(name.to_string().into(), oid);
360        Ok(oid)
361    }
362
363    pub(crate) async fn get_nullable_for_columns(
364        &mut self,
365        stmt_id: Oid,
366        meta: &PgStatementMetadata,
367    ) -> Result<Vec<Option<bool>>, Error> {
368        if meta.columns.is_empty() {
369            return Ok(vec![]);
370        }
371
372        let mut nullable_query = String::from("SELECT NOT pg_attribute.attnotnull FROM (VALUES ");
373        let mut args = PgArguments::default();
374
375        for (i, (column, bind)) in meta.columns.iter().zip((1..).step_by(3)).enumerate() {
376            if !args.buffer.is_empty() {
377                nullable_query += ", ";
378            }
379
380            let _ = write!(
381                nullable_query,
382                "(${}::int4, ${}::int4, ${}::int2)",
383                bind,
384                bind + 1,
385                bind + 2
386            );
387
388            args.add(i as i32);
389            args.add(column.relation_id);
390            args.add(column.relation_attribute_no);
391        }
392
393        nullable_query.push_str(
394            ") as col(idx, table_id, col_idx) \
395            LEFT JOIN pg_catalog.pg_attribute \
396                ON table_id IS NOT NULL \
397               AND attrelid = table_id \
398               AND attnum = col_idx \
399            ORDER BY col.idx",
400        );
401
402        let mut nullables = query_scalar_with::<_, Option<bool>, _>(&nullable_query, args)
403            .fetch_all(&mut *self)
404            .await?;
405
406        // if it's cockroachdb skip this step #1248
407        if !self.stream.parameter_statuses.contains_key("crdb_version") {
408            // patch up our null inference with data from EXPLAIN
409            let nullable_patch = self
410                .nullables_from_explain(stmt_id, meta.parameters.len())
411                .await?;
412
413            for (nullable, patch) in nullables.iter_mut().zip(nullable_patch) {
414                *nullable = patch.or(*nullable);
415            }
416        }
417
418        Ok(nullables)
419    }
420
421    /// Infer nullability for columns of this statement using EXPLAIN VERBOSE.
422    ///
423    /// This currently only marks columns that are on the inner half of an outer join
424    /// and returns `None` for all others.
425    async fn nullables_from_explain(
426        &mut self,
427        stmt_id: Oid,
428        params_len: usize,
429    ) -> Result<Vec<Option<bool>>, Error> {
430        let mut explain = format!(
431            "EXPLAIN (VERBOSE, FORMAT JSON) EXECUTE sqlx_s_{}",
432            stmt_id.0
433        );
434        let mut comma = false;
435
436        if params_len > 0 {
437            explain += "(";
438
439            // fill the arguments list with NULL, which should theoretically be valid
440            for _ in 0..params_len {
441                if comma {
442                    explain += ", ";
443                }
444
445                explain += "NULL";
446                comma = true;
447            }
448
449            explain += ")";
450        }
451
452        let (Json([explain]),): (Json<[Explain; 1]>,) = query_as(&explain).fetch_one(self).await?;
453
454        let mut nullables = Vec::new();
455
456        if let Some(outputs) = &explain.plan.output {
457            nullables.resize(outputs.len(), None);
458            visit_plan(&explain.plan, outputs, &mut nullables);
459        }
460
461        Ok(nullables)
462    }
463}
464
465fn visit_plan(plan: &Plan, outputs: &[String], nullables: &mut Vec<Option<bool>>) {
466    if let Some(plan_outputs) = &plan.output {
467        // all outputs of a Full Join must be marked nullable
468        // otherwise, all outputs of the inner half of an outer join must be marked nullable
469        if plan.join_type.as_deref() == Some("Full")
470            || plan.parent_relation.as_deref() == Some("Inner")
471        {
472            for output in plan_outputs {
473                if let Some(i) = outputs.iter().position(|o| o == output) {
474                    // N.B. this may produce false positives but those don't cause runtime errors
475                    nullables[i] = Some(true);
476                }
477            }
478        }
479    }
480
481    if let Some(plans) = &plan.plans {
482        if let Some("Left") | Some("Right") = plan.join_type.as_deref() {
483            for plan in plans {
484                visit_plan(plan, outputs, nullables);
485            }
486        }
487    }
488}
489
/// One element of the JSON array that `nullables_from_explain` deserializes from the
/// `EXPLAIN (VERBOSE, FORMAT JSON)` result.
#[derive(serde::Deserialize)]
struct Explain {
    /// The root plan node, read from the `"Plan"` key.
    #[serde(rename = "Plan")]
    plan: Plan,
}
495
/// A single node of the EXPLAIN plan tree; only the keys inspected by `visit_plan`
/// are deserialized, and each of them is optional.
#[derive(serde::Deserialize)]
struct Plan {
    /// Value of the `"Join Type"` key; `visit_plan` checks for `Full`, `Left` and `Right`.
    #[serde(rename = "Join Type")]
    join_type: Option<String>,
    /// Value of the `"Parent Relationship"` key; `visit_plan` checks for `Inner`.
    #[serde(rename = "Parent Relationship")]
    parent_relation: Option<String>,
    /// Value of the `"Output"` key: the output expressions of this node.
    #[serde(rename = "Output")]
    output: Option<Vec<String>>,
    /// Value of the `"Plans"` key: the child nodes of this node.
    #[serde(rename = "Plans")]
    plans: Option<Vec<Plan>>,
}