sqlx_postgres/connection/
describe.rs

1use crate::error::Error;
2use crate::ext::ustr::UStr;
3use crate::io::StatementId;
4use crate::message::{ParameterDescription, RowDescription};
5use crate::query_as::query_as;
6use crate::query_scalar::query_scalar;
7use crate::statement::PgStatementMetadata;
8use crate::type_info::{PgArrayOf, PgCustomType, PgType, PgTypeKind};
9use crate::types::Json;
10use crate::types::Oid;
11use crate::HashMap;
12use crate::{PgColumn, PgConnection, PgTypeInfo};
13use smallvec::SmallVec;
14use sqlx_core::query_builder::QueryBuilder;
15use std::sync::Arc;
16
/// The kind of a Postgres type, as stored in the `pg_type.typtype` column.
///
/// Each variant corresponds to one of the single-character codes handled by the
/// `TryFrom<i8>` implementation below.
///
/// See <https://www.postgresql.org/docs/13/catalog-pg-type.html>
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
enum TypType {
    /// Code `b`.
    Base,
    /// Code `c`.
    Composite,
    /// Code `d`.
    Domain,
    /// Code `e`.
    Enum,
    /// Code `p`.
    Pseudo,
    /// Code `r`.
    Range,
}

impl TryFrom<i8> for TypType {
    type Error = ();

    /// Decodes the raw `typtype` byte.
    ///
    /// Returns `Err(())` for negative values and for any code that has no
    /// matching variant.
    fn try_from(t: i8) -> Result<Self, Self::Error> {
        // A negative `i8` cannot be an ASCII code, so a failed conversion to `u8`
        // lands in the catch-all arm together with unrecognized codes.
        match u8::try_from(t) {
            Ok(b'b') => Ok(Self::Base),
            Ok(b'c') => Ok(Self::Composite),
            Ok(b'd') => Ok(Self::Domain),
            Ok(b'e') => Ok(Self::Enum),
            Ok(b'p') => Ok(Self::Pseudo),
            Ok(b'r') => Ok(Self::Range),
            _ => Err(()),
        }
    }
}
48
/// The category of a Postgres type, as stored in the `pg_type.typcategory` column.
///
/// Each variant corresponds to one of the single-character codes handled by the
/// `TryFrom<i8>` implementation below.
///
/// See <https://www.postgresql.org/docs/13/catalog-pg-type.html#CATALOG-TYPCATEGORY-TABLE>
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
enum TypCategory {
    /// Code `A`.
    Array,
    /// Code `B`.
    Boolean,
    /// Code `C`.
    Composite,
    /// Code `D`.
    DateTime,
    /// Code `E`.
    Enum,
    /// Code `G`.
    Geometric,
    /// Code `I`.
    Network,
    /// Code `N`.
    Numeric,
    /// Code `P`.
    Pseudo,
    /// Code `R`.
    Range,
    /// Code `S`.
    String,
    /// Code `T`.
    Timespan,
    /// Code `U`.
    User,
    /// Code `V`.
    BitString,
    /// Code `X`.
    Unknown,
}

impl TryFrom<i8> for TypCategory {
    type Error = ();

    /// Decodes the raw `typcategory` byte.
    ///
    /// Returns `Err(())` for negative values and for any code that has no
    /// matching variant.
    fn try_from(c: i8) -> Result<Self, Self::Error> {
        // A negative `i8` cannot be an ASCII code, so a failed conversion to `u8`
        // lands in the catch-all arm together with unrecognized codes.
        match u8::try_from(c) {
            Ok(b'A') => Ok(Self::Array),
            Ok(b'B') => Ok(Self::Boolean),
            Ok(b'C') => Ok(Self::Composite),
            Ok(b'D') => Ok(Self::DateTime),
            Ok(b'E') => Ok(Self::Enum),
            Ok(b'G') => Ok(Self::Geometric),
            Ok(b'I') => Ok(Self::Network),
            Ok(b'N') => Ok(Self::Numeric),
            Ok(b'P') => Ok(Self::Pseudo),
            Ok(b'R') => Ok(Self::Range),
            Ok(b'S') => Ok(Self::String),
            Ok(b'T') => Ok(Self::Timespan),
            Ok(b'U') => Ok(Self::User),
            Ok(b'V') => Ok(Self::BitString),
            Ok(b'X') => Ok(Self::Unknown),
            _ => Err(()),
        }
    }
}
98
99impl PgConnection {
100    pub(super) async fn handle_row_description(
101        &mut self,
102        desc: Option<RowDescription>,
103        should_fetch: bool,
104    ) -> Result<(Vec<PgColumn>, HashMap<UStr, usize>), Error> {
105        let mut columns = Vec::new();
106        let mut column_names = HashMap::new();
107
108        let desc = if let Some(desc) = desc {
109            desc
110        } else {
111            // no rows
112            return Ok((columns, column_names));
113        };
114
115        columns.reserve(desc.fields.len());
116        column_names.reserve(desc.fields.len());
117
118        for (index, field) in desc.fields.into_iter().enumerate() {
119            let name = UStr::from(field.name);
120
121            let type_info = self
122                .maybe_fetch_type_info_by_oid(field.data_type_id, should_fetch)
123                .await?;
124
125            let column = PgColumn {
126                ordinal: index,
127                name: name.clone(),
128                type_info,
129                relation_id: field.relation_id,
130                relation_attribute_no: field.relation_attribute_no,
131            };
132
133            columns.push(column);
134            column_names.insert(name, index);
135        }
136
137        Ok((columns, column_names))
138    }
139
140    pub(super) async fn handle_parameter_description(
141        &mut self,
142        desc: ParameterDescription,
143    ) -> Result<Vec<PgTypeInfo>, Error> {
144        let mut params = Vec::with_capacity(desc.types.len());
145
146        for ty in desc.types {
147            params.push(self.maybe_fetch_type_info_by_oid(ty, true).await?);
148        }
149
150        Ok(params)
151    }
152
153    async fn maybe_fetch_type_info_by_oid(
154        &mut self,
155        oid: Oid,
156        should_fetch: bool,
157    ) -> Result<PgTypeInfo, Error> {
158        // first we check if this is a built-in type
159        // in the average application, the vast majority of checks should flow through this
160        if let Some(info) = PgTypeInfo::try_from_oid(oid) {
161            return Ok(info);
162        }
163
164        // next we check a local cache for user-defined type names <-> object id
165        if let Some(info) = self.inner.cache_type_info.get(&oid) {
166            return Ok(info.clone());
167        }
168
169        // fallback to asking the database directly for a type name
170        if should_fetch {
171            // we're boxing this future here so we can use async recursion
172            let info = Box::pin(async { self.fetch_type_by_oid(oid).await }).await?;
173
174            // cache the type name <-> oid relationship in a paired hashmap
175            // so we don't come down this road again
176            self.inner.cache_type_info.insert(oid, info.clone());
177            self.inner
178                .cache_type_oid
179                .insert(info.0.name().to_string().into(), oid);
180
181            Ok(info)
182        } else {
183            // we are not in a place that *can* run a query
184            // this generally means we are in the middle of another query
185            // this _should_ only happen for complex types sent through the TEXT protocol
186            // we're open to ideas to correct this.. but it'd probably be more efficient to figure
187            // out a way to "prime" the type cache for connections rather than make this
188            // fallback work correctly for complex user-defined types for the TEXT protocol
189            Ok(PgTypeInfo(PgType::DeclareWithOid(oid)))
190        }
191    }
192
193    async fn fetch_type_by_oid(&mut self, oid: Oid) -> Result<PgTypeInfo, Error> {
194        let (name, typ_type, category, relation_id, element, base_type): (
195            String,
196            i8,
197            i8,
198            Oid,
199            Oid,
200            Oid,
201        ) = query_as(
202            // Converting the OID to `regtype` and then `text` will give us the name that
203            // the type will need to be found at by search_path.
204            "SELECT oid::regtype::text, \
205                     typtype, \
206                     typcategory, \
207                     typrelid, \
208                     typelem, \
209                     typbasetype \
210                     FROM pg_catalog.pg_type \
211                     WHERE oid = $1",
212        )
213        .bind(oid)
214        .fetch_one(&mut *self)
215        .await?;
216
217        let typ_type = TypType::try_from(typ_type);
218        let category = TypCategory::try_from(category);
219
220        match (typ_type, category) {
221            (Ok(TypType::Domain), _) => self.fetch_domain_by_oid(oid, base_type, name).await,
222
223            (Ok(TypType::Base), Ok(TypCategory::Array)) => {
224                Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
225                    kind: PgTypeKind::Array(
226                        self.maybe_fetch_type_info_by_oid(element, true).await?,
227                    ),
228                    name: name.into(),
229                    oid,
230                }))))
231            }
232
233            (Ok(TypType::Pseudo), Ok(TypCategory::Pseudo)) => {
234                Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
235                    kind: PgTypeKind::Pseudo,
236                    name: name.into(),
237                    oid,
238                }))))
239            }
240
241            (Ok(TypType::Range), Ok(TypCategory::Range)) => {
242                self.fetch_range_by_oid(oid, name).await
243            }
244
245            (Ok(TypType::Enum), Ok(TypCategory::Enum)) => self.fetch_enum_by_oid(oid, name).await,
246
247            (Ok(TypType::Composite), Ok(TypCategory::Composite)) => {
248                self.fetch_composite_by_oid(oid, relation_id, name).await
249            }
250
251            _ => Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
252                kind: PgTypeKind::Simple,
253                name: name.into(),
254                oid,
255            })))),
256        }
257    }
258
259    async fn fetch_enum_by_oid(&mut self, oid: Oid, name: String) -> Result<PgTypeInfo, Error> {
260        let variants: Vec<String> = query_scalar(
261            r#"
262SELECT enumlabel
263FROM pg_catalog.pg_enum
264WHERE enumtypid = $1
265ORDER BY enumsortorder
266            "#,
267        )
268        .bind(oid)
269        .fetch_all(self)
270        .await?;
271
272        Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
273            oid,
274            name: name.into(),
275            kind: PgTypeKind::Enum(Arc::from(variants)),
276        }))))
277    }
278
279    async fn fetch_composite_by_oid(
280        &mut self,
281        oid: Oid,
282        relation_id: Oid,
283        name: String,
284    ) -> Result<PgTypeInfo, Error> {
285        let raw_fields: Vec<(String, Oid)> = query_as(
286            r#"
287SELECT attname, atttypid
288FROM pg_catalog.pg_attribute
289WHERE attrelid = $1
290AND NOT attisdropped
291AND attnum > 0
292ORDER BY attnum
293                "#,
294        )
295        .bind(relation_id)
296        .fetch_all(&mut *self)
297        .await?;
298
299        let mut fields = Vec::new();
300
301        for (field_name, field_oid) in raw_fields.into_iter() {
302            let field_type = self.maybe_fetch_type_info_by_oid(field_oid, true).await?;
303
304            fields.push((field_name, field_type));
305        }
306
307        Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
308            oid,
309            name: name.into(),
310            kind: PgTypeKind::Composite(Arc::from(fields)),
311        }))))
312    }
313
314    async fn fetch_domain_by_oid(
315        &mut self,
316        oid: Oid,
317        base_type: Oid,
318        name: String,
319    ) -> Result<PgTypeInfo, Error> {
320        let base_type = self.maybe_fetch_type_info_by_oid(base_type, true).await?;
321
322        Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
323            oid,
324            name: name.into(),
325            kind: PgTypeKind::Domain(base_type),
326        }))))
327    }
328
329    async fn fetch_range_by_oid(&mut self, oid: Oid, name: String) -> Result<PgTypeInfo, Error> {
330        let element_oid: Oid = query_scalar(
331            r#"
332SELECT rngsubtype
333FROM pg_catalog.pg_range
334WHERE rngtypid = $1
335                "#,
336        )
337        .bind(oid)
338        .fetch_one(&mut *self)
339        .await?;
340
341        let element = self.maybe_fetch_type_info_by_oid(element_oid, true).await?;
342
343        Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
344            kind: PgTypeKind::Range(element),
345            name: name.into(),
346            oid,
347        }))))
348    }
349
350    pub(crate) async fn resolve_type_id(&mut self, ty: &PgType) -> Result<Oid, Error> {
351        if let Some(oid) = ty.try_oid() {
352            return Ok(oid);
353        }
354
355        match ty {
356            PgType::DeclareWithName(name) => self.fetch_type_id_by_name(name).await,
357            PgType::DeclareArrayOf(array) => self.fetch_array_type_id(array).await,
358            // `.try_oid()` should return `Some()` or it should be covered here
359            _ => unreachable!("(bug) OID should be resolvable for type {ty:?}"),
360        }
361    }
362
363    pub(crate) async fn fetch_type_id_by_name(&mut self, name: &str) -> Result<Oid, Error> {
364        if let Some(oid) = self.inner.cache_type_oid.get(name) {
365            return Ok(*oid);
366        }
367
368        // language=SQL
369        let (oid,): (Oid,) = query_as("SELECT $1::regtype::oid")
370            .bind(name)
371            .fetch_optional(&mut *self)
372            .await?
373            .ok_or_else(|| Error::TypeNotFound {
374                type_name: name.into(),
375            })?;
376
377        self.inner
378            .cache_type_oid
379            .insert(name.to_string().into(), oid);
380        Ok(oid)
381    }
382
383    pub(crate) async fn fetch_array_type_id(&mut self, array: &PgArrayOf) -> Result<Oid, Error> {
384        if let Some(oid) = self
385            .inner
386            .cache_type_oid
387            .get(&array.elem_name)
388            .and_then(|elem_oid| self.inner.cache_elem_type_to_array.get(elem_oid))
389        {
390            return Ok(*oid);
391        }
392
393        // language=SQL
394        let (elem_oid, array_oid): (Oid, Oid) =
395            query_as("SELECT oid, typarray FROM pg_catalog.pg_type WHERE oid = $1::regtype::oid")
396                .bind(&*array.elem_name)
397                .fetch_optional(&mut *self)
398                .await?
399                .ok_or_else(|| Error::TypeNotFound {
400                    type_name: array.name.to_string(),
401                })?;
402
403        // Avoids copying `elem_name` until necessary
404        self.inner
405            .cache_type_oid
406            .entry_ref(&array.elem_name)
407            .insert(elem_oid);
408        self.inner
409            .cache_elem_type_to_array
410            .insert(elem_oid, array_oid);
411
412        Ok(array_oid)
413    }
414
415    /// Check whether EXPLAIN statements are supported by the current connection
416    fn is_explain_available(&self) -> bool {
417        let parameter_statuses = &self.inner.stream.parameter_statuses;
418        let is_cockroachdb = parameter_statuses.contains_key("crdb_version");
419        let is_materialize = parameter_statuses.contains_key("mz_version");
420        let is_questdb = parameter_statuses.contains_key("questdb_version");
421        !is_cockroachdb && !is_materialize && !is_questdb
422    }
423
424    pub(crate) async fn get_nullable_for_columns(
425        &mut self,
426        stmt_id: StatementId,
427        meta: &PgStatementMetadata,
428    ) -> Result<Vec<Option<bool>>, Error> {
429        if meta.columns.is_empty() {
430            return Ok(vec![]);
431        }
432
433        if meta.columns.len() * 3 > 65535 {
434            tracing::debug!(
435                ?stmt_id,
436                num_columns = meta.columns.len(),
437                "number of columns in query is too large to pull nullability for"
438            );
439        }
440
441        // Query for NOT NULL constraints for each column in the query.
442        //
443        // This will include columns that don't have a `relation_id` (are not from a table);
444        // assuming those are a minority of columns, it's less code to _not_ work around it
445        // and just let Postgres return `NULL`.
446        //
447        // Use `UNION ALL` syntax instead of `VALUES` due to frequent lack of
448        // support for `VALUES` in pgwire supported databases.
449        let mut nullable_query = QueryBuilder::new("SELECT NOT attnotnull FROM ( ");
450        let mut separated = nullable_query.separated("UNION ALL ");
451
452        let mut column_iter = meta.columns.iter().zip(0i32..);
453        if let Some((column, i)) = column_iter.next() {
454            separated.push("( SELECT ");
455            separated
456                .push_bind_unseparated(i)
457                .push_unseparated("::int4 AS idx, ");
458            separated
459                .push_bind_unseparated(column.relation_id)
460                .push_unseparated("::int4 AS table_id, ");
461            separated
462                .push_bind_unseparated(column.relation_attribute_no)
463                .push_unseparated("::int2 AS col_idx ) ");
464        }
465
466        for (column, i) in column_iter {
467            separated.push("( SELECT ");
468            separated
469                .push_bind_unseparated(i)
470                .push_unseparated("::int4, ");
471            separated
472                .push_bind_unseparated(column.relation_id)
473                .push_unseparated("::int4, ");
474            separated
475                .push_bind_unseparated(column.relation_attribute_no)
476                .push_unseparated("::int2 ) ");
477        }
478
479        nullable_query.push(
480            ") AS col LEFT JOIN pg_catalog.pg_attribute \
481                ON table_id IS NOT NULL \
482               AND attrelid = table_id \
483               AND attnum = col_idx \
484            ORDER BY idx",
485        );
486
487        let mut nullables: Vec<Option<bool>> = nullable_query
488            .build_query_scalar()
489            .fetch_all(&mut *self)
490            .await
491            .map_err(|e| {
492                err_protocol!(
493                    "error from nullables query: {e}; query: {:?}",
494                    nullable_query.sql()
495                )
496            })?;
497
498        // If the server doesn't support EXPLAIN statements, skip this step (#1248).
499        if self.is_explain_available() {
500            // patch up our null inference with data from EXPLAIN
501            let nullable_patch = self
502                .nullables_from_explain(stmt_id, meta.parameters.len())
503                .await?;
504
505            for (nullable, patch) in nullables.iter_mut().zip(nullable_patch) {
506                *nullable = patch.or(*nullable);
507            }
508        }
509
510        Ok(nullables)
511    }
512
513    /// Infer nullability for columns of this statement using EXPLAIN VERBOSE.
514    ///
515    /// This currently only marks columns that are on the inner half of an outer join
516    /// and returns `None` for all others.
517    async fn nullables_from_explain(
518        &mut self,
519        stmt_id: StatementId,
520        params_len: usize,
521    ) -> Result<Vec<Option<bool>>, Error> {
522        let stmt_id_display = stmt_id
523            .display()
524            .ok_or_else(|| err_protocol!("cannot EXPLAIN unnamed statement: {stmt_id:?}"))?;
525
526        let mut explain = format!("EXPLAIN (VERBOSE, FORMAT JSON) EXECUTE {stmt_id_display}");
527        let mut comma = false;
528
529        if params_len > 0 {
530            explain += "(";
531
532            // fill the arguments list with NULL, which should theoretically be valid
533            for _ in 0..params_len {
534                if comma {
535                    explain += ", ";
536                }
537
538                explain += "NULL";
539                comma = true;
540            }
541
542            explain += ")";
543        }
544
545        let (Json(explains),): (Json<SmallVec<[Explain; 1]>>,) =
546            query_as(&explain).fetch_one(self).await?;
547
548        let mut nullables = Vec::new();
549
550        if let Some(Explain::Plan {
551            plan:
552                plan @ Plan {
553                    output: Some(ref outputs),
554                    ..
555                },
556        }) = explains.first()
557        {
558            nullables.resize(outputs.len(), None);
559            visit_plan(plan, outputs, &mut nullables);
560        }
561
562        Ok(nullables)
563    }
564}
565
566fn visit_plan(plan: &Plan, outputs: &[String], nullables: &mut Vec<Option<bool>>) {
567    if let Some(plan_outputs) = &plan.output {
568        // all outputs of a Full Join must be marked nullable
569        // otherwise, all outputs of the inner half of an outer join must be marked nullable
570        if plan.join_type.as_deref() == Some("Full")
571            || plan.parent_relation.as_deref() == Some("Inner")
572        {
573            for output in plan_outputs {
574                if let Some(i) = outputs.iter().position(|o| o == output) {
575                    // N.B. this may produce false positives but those don't cause runtime errors
576                    nullables[i] = Some(true);
577                }
578            }
579        }
580    }
581
582    if let Some(plans) = &plan.plans {
583        if let Some("Left") | Some("Right") = plan.join_type.as_deref() {
584            for plan in plans {
585                visit_plan(plan, outputs, nullables);
586            }
587        }
588    }
589}
590
/// One element of the JSON array returned by `EXPLAIN (FORMAT JSON)`.
///
/// The enum is `untagged`, so deserialization tries `Plan` first and falls back
/// to `Other` for anything that does not look like a plan object.
#[derive(serde::Deserialize, Debug)]
#[serde(untagged)]
enum Explain {
    // NOTE: the returned JSON may not contain a `plan` field, for example, with `CALL` statements:
    // https://github.com/launchbadge/sqlx/issues/1449
    //
    // In this case, we should just fall back to assuming all is nullable.
    //
    // It may also contain additional fields we don't care about, which should not break parsing:
    // https://github.com/launchbadge/sqlx/issues/2587
    // https://github.com/launchbadge/sqlx/issues/2622
    /// An object carrying a query plan under its `"Plan"` key.
    Plan {
        #[serde(rename = "Plan")]
        plan: Plan,
    },

    // This ensures that parsing never technically fails.
    //
    // We don't want to specifically expect `"Utility Statement"` because there might be other cases
    // and we don't care unless it contains a query plan anyway.
    /// Any other JSON value; its content is discarded.
    Other(serde::de::IgnoredAny),
}
613
/// The subset of an EXPLAIN plan node that nullability inference looks at.
///
/// Every field is optional, so a node that lacks any of these keys still
/// deserializes.
#[derive(serde::Deserialize, Debug)]
struct Plan {
    /// Value of the node's `"Join Type"` key; `visit_plan` compares it against
    /// `"Full"`, `"Left"` and `"Right"`.
    #[serde(rename = "Join Type")]
    join_type: Option<String>,
    /// Value of the node's `"Parent Relationship"` key; `visit_plan` compares it
    /// against `"Inner"`.
    #[serde(rename = "Parent Relationship")]
    parent_relation: Option<String>,
    /// Value of the node's `"Output"` key: a list of strings that `visit_plan`
    /// matches against the root plan's outputs.
    #[serde(rename = "Output")]
    output: Option<Vec<String>>,
    /// Value of the node's `"Plans"` key: the child plan nodes.
    #[serde(rename = "Plans")]
    plans: Option<Vec<Plan>>,
}
625
/// Checks that the three known shapes of EXPLAIN JSON output deserialize into
/// the expected `Explain` variant: a plain plan, a plan with an extra sibling
/// field, and a bare `"Utility Statement"` string.
#[test]
fn explain_parsing() {
    /// Parses a one-element EXPLAIN array, panicking on malformed input.
    fn parse(json: &str) -> [Explain; 1] {
        serde_json::from_str::<[Explain; 1]>(json).unwrap()
    }

    let normal_plan = r#"[
   {
     "Plan": {
       "Node Type": "Result",
       "Parallel Aware": false,
       "Async Capable": false,
       "Startup Cost": 0.00,
       "Total Cost": 0.01,
       "Plan Rows": 1,
       "Plan Width": 4,
       "Output": ["1"]
     }
   }
]"#;

    // https://github.com/launchbadge/sqlx/issues/2622
    let extra_field = r#"[
   {                                        
     "Plan": {                              
       "Node Type": "Result",               
       "Parallel Aware": false,             
       "Async Capable": false,              
       "Startup Cost": 0.00,                
       "Total Cost": 0.01,                  
       "Plan Rows": 1,                      
       "Plan Width": 4,                     
       "Output": ["1"]                      
     },                                     
     "Query Identifier": 1147616880456321454
   }                                        
]"#;

    // https://github.com/launchbadge/sqlx/issues/1449
    let utility_statement = r#"["Utility Statement"]"#;

    // Parse everything up front so a malformed fixture fails before any assertion runs.
    let normal_plan_parsed = parse(normal_plan);
    let extra_field_parsed = parse(extra_field);
    let utility_statement_parsed = parse(utility_statement);

    let normal_plan_ok = matches!(normal_plan_parsed, [Explain::Plan { plan: Plan { .. } }]);
    assert!(
        normal_plan_ok,
        "unexpected parse from {normal_plan:?}: {normal_plan_parsed:?}"
    );

    let extra_field_ok = matches!(extra_field_parsed, [Explain::Plan { plan: Plan { .. } }]);
    assert!(
        extra_field_ok,
        "unexpected parse from {extra_field:?}: {extra_field_parsed:?}"
    );

    let utility_statement_ok = matches!(utility_statement_parsed, [Explain::Other(_)]);
    assert!(
        utility_statement_ok,
        "unexpected parse from {utility_statement:?}: {utility_statement_parsed:?}"
    );
}