1use crate::error::Error;
2use crate::ext::ustr::UStr;
3use crate::io::StatementId;
4use crate::message::{ParameterDescription, RowDescription};
5use crate::query_as::query_as;
6use crate::query_scalar::query_scalar;
7use crate::statement::PgStatementMetadata;
8use crate::type_info::{PgArrayOf, PgCustomType, PgType, PgTypeKind};
9use crate::types::Json;
10use crate::types::Oid;
11use crate::HashMap;
12use crate::{PgColumn, PgConnection, PgTypeInfo};
13use smallvec::SmallVec;
14use sqlx_core::query_builder::QueryBuilder;
15use std::sync::Arc;
16
/// Decoded `pg_catalog.pg_type.typtype`: the broad kind of a type.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
enum TypType {
    Base,
    Composite,
    Domain,
    Enum,
    Pseudo,
    Range,
}

impl TryFrom<i8> for TypType {
    type Error = ();

    /// Maps a raw `typtype` code to a [`TypType`].
    ///
    /// Recognised codes are `b`, `c`, `d`, `e`, `p` and `r`; any other value,
    /// including negative ones, yields `Err(())`.
    fn try_from(t: i8) -> Result<Self, Self::Error> {
        // The catalog value is decoded as `i8`; every recognised code is ASCII,
        // so anything that does not fit in a `u8` is rejected up front.
        let code = u8::try_from(t).map_err(|_| ())?;

        match code {
            b'b' => Ok(TypType::Base),
            b'c' => Ok(TypType::Composite),
            b'd' => Ok(TypType::Domain),
            b'e' => Ok(TypType::Enum),
            b'p' => Ok(TypType::Pseudo),
            b'r' => Ok(TypType::Range),
            _ => Err(()),
        }
    }
}
48
/// Decoded `pg_catalog.pg_type.typcategory`: the category code of a type.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
enum TypCategory {
    Array,
    Boolean,
    Composite,
    DateTime,
    Enum,
    Geometric,
    Network,
    Numeric,
    Pseudo,
    Range,
    String,
    Timespan,
    User,
    BitString,
    Unknown,
}

impl TryFrom<i8> for TypCategory {
    type Error = ();

    /// Maps a raw `typcategory` code to a [`TypCategory`].
    ///
    /// Unlisted codes, and values that are not valid `u8`s, yield `Err(())`.
    fn try_from(c: i8) -> Result<Self, Self::Error> {
        // Code → category lookup table; each code appears exactly once.
        const TABLE: [(u8, TypCategory); 15] = [
            (b'A', TypCategory::Array),
            (b'B', TypCategory::Boolean),
            (b'C', TypCategory::Composite),
            (b'D', TypCategory::DateTime),
            (b'E', TypCategory::Enum),
            (b'G', TypCategory::Geometric),
            (b'I', TypCategory::Network),
            (b'N', TypCategory::Numeric),
            (b'P', TypCategory::Pseudo),
            (b'R', TypCategory::Range),
            (b'S', TypCategory::String),
            (b'T', TypCategory::Timespan),
            (b'U', TypCategory::User),
            (b'V', TypCategory::BitString),
            (b'X', TypCategory::Unknown),
        ];

        let byte = u8::try_from(c).map_err(|_| ())?;

        TABLE
            .iter()
            .find(|entry| entry.0 == byte)
            .map(|entry| entry.1)
            .ok_or(())
    }
}
98
99impl PgConnection {
100 pub(super) async fn handle_row_description(
101 &mut self,
102 desc: Option<RowDescription>,
103 should_fetch: bool,
104 ) -> Result<(Vec<PgColumn>, HashMap<UStr, usize>), Error> {
105 let mut columns = Vec::new();
106 let mut column_names = HashMap::new();
107
108 let desc = if let Some(desc) = desc {
109 desc
110 } else {
111 return Ok((columns, column_names));
113 };
114
115 columns.reserve(desc.fields.len());
116 column_names.reserve(desc.fields.len());
117
118 for (index, field) in desc.fields.into_iter().enumerate() {
119 let name = UStr::from(field.name);
120
121 let type_info = self
122 .maybe_fetch_type_info_by_oid(field.data_type_id, should_fetch)
123 .await?;
124
125 let column = PgColumn {
126 ordinal: index,
127 name: name.clone(),
128 type_info,
129 relation_id: field.relation_id,
130 relation_attribute_no: field.relation_attribute_no,
131 };
132
133 columns.push(column);
134 column_names.insert(name, index);
135 }
136
137 Ok((columns, column_names))
138 }
139
140 pub(super) async fn handle_parameter_description(
141 &mut self,
142 desc: ParameterDescription,
143 ) -> Result<Vec<PgTypeInfo>, Error> {
144 let mut params = Vec::with_capacity(desc.types.len());
145
146 for ty in desc.types {
147 params.push(self.maybe_fetch_type_info_by_oid(ty, true).await?);
148 }
149
150 Ok(params)
151 }
152
153 async fn maybe_fetch_type_info_by_oid(
154 &mut self,
155 oid: Oid,
156 should_fetch: bool,
157 ) -> Result<PgTypeInfo, Error> {
158 if let Some(info) = PgTypeInfo::try_from_oid(oid) {
161 return Ok(info);
162 }
163
164 if let Some(info) = self.inner.cache_type_info.get(&oid) {
166 return Ok(info.clone());
167 }
168
169 if should_fetch {
171 let info = Box::pin(async { self.fetch_type_by_oid(oid).await }).await?;
173
174 self.inner.cache_type_info.insert(oid, info.clone());
177 self.inner
178 .cache_type_oid
179 .insert(info.0.name().to_string().into(), oid);
180
181 Ok(info)
182 } else {
183 Ok(PgTypeInfo(PgType::DeclareWithOid(oid)))
190 }
191 }
192
193 async fn fetch_type_by_oid(&mut self, oid: Oid) -> Result<PgTypeInfo, Error> {
194 let (name, typ_type, category, relation_id, element, base_type): (
195 String,
196 i8,
197 i8,
198 Oid,
199 Oid,
200 Oid,
201 ) = query_as(
202 "SELECT oid::regtype::text, \
205 typtype, \
206 typcategory, \
207 typrelid, \
208 typelem, \
209 typbasetype \
210 FROM pg_catalog.pg_type \
211 WHERE oid = $1",
212 )
213 .bind(oid)
214 .fetch_one(&mut *self)
215 .await?;
216
217 let typ_type = TypType::try_from(typ_type);
218 let category = TypCategory::try_from(category);
219
220 match (typ_type, category) {
221 (Ok(TypType::Domain), _) => self.fetch_domain_by_oid(oid, base_type, name).await,
222
223 (Ok(TypType::Base), Ok(TypCategory::Array)) => {
224 Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
225 kind: PgTypeKind::Array(
226 self.maybe_fetch_type_info_by_oid(element, true).await?,
227 ),
228 name: name.into(),
229 oid,
230 }))))
231 }
232
233 (Ok(TypType::Pseudo), Ok(TypCategory::Pseudo)) => {
234 Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
235 kind: PgTypeKind::Pseudo,
236 name: name.into(),
237 oid,
238 }))))
239 }
240
241 (Ok(TypType::Range), Ok(TypCategory::Range)) => {
242 self.fetch_range_by_oid(oid, name).await
243 }
244
245 (Ok(TypType::Enum), Ok(TypCategory::Enum)) => self.fetch_enum_by_oid(oid, name).await,
246
247 (Ok(TypType::Composite), Ok(TypCategory::Composite)) => {
248 self.fetch_composite_by_oid(oid, relation_id, name).await
249 }
250
251 _ => Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
252 kind: PgTypeKind::Simple,
253 name: name.into(),
254 oid,
255 })))),
256 }
257 }
258
259 async fn fetch_enum_by_oid(&mut self, oid: Oid, name: String) -> Result<PgTypeInfo, Error> {
260 let variants: Vec<String> = query_scalar(
261 r#"
262SELECT enumlabel
263FROM pg_catalog.pg_enum
264WHERE enumtypid = $1
265ORDER BY enumsortorder
266 "#,
267 )
268 .bind(oid)
269 .fetch_all(self)
270 .await?;
271
272 Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
273 oid,
274 name: name.into(),
275 kind: PgTypeKind::Enum(Arc::from(variants)),
276 }))))
277 }
278
279 async fn fetch_composite_by_oid(
280 &mut self,
281 oid: Oid,
282 relation_id: Oid,
283 name: String,
284 ) -> Result<PgTypeInfo, Error> {
285 let raw_fields: Vec<(String, Oid)> = query_as(
286 r#"
287SELECT attname, atttypid
288FROM pg_catalog.pg_attribute
289WHERE attrelid = $1
290AND NOT attisdropped
291AND attnum > 0
292ORDER BY attnum
293 "#,
294 )
295 .bind(relation_id)
296 .fetch_all(&mut *self)
297 .await?;
298
299 let mut fields = Vec::new();
300
301 for (field_name, field_oid) in raw_fields.into_iter() {
302 let field_type = self.maybe_fetch_type_info_by_oid(field_oid, true).await?;
303
304 fields.push((field_name, field_type));
305 }
306
307 Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
308 oid,
309 name: name.into(),
310 kind: PgTypeKind::Composite(Arc::from(fields)),
311 }))))
312 }
313
314 async fn fetch_domain_by_oid(
315 &mut self,
316 oid: Oid,
317 base_type: Oid,
318 name: String,
319 ) -> Result<PgTypeInfo, Error> {
320 let base_type = self.maybe_fetch_type_info_by_oid(base_type, true).await?;
321
322 Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
323 oid,
324 name: name.into(),
325 kind: PgTypeKind::Domain(base_type),
326 }))))
327 }
328
329 async fn fetch_range_by_oid(&mut self, oid: Oid, name: String) -> Result<PgTypeInfo, Error> {
330 let element_oid: Oid = query_scalar(
331 r#"
332SELECT rngsubtype
333FROM pg_catalog.pg_range
334WHERE rngtypid = $1
335 "#,
336 )
337 .bind(oid)
338 .fetch_one(&mut *self)
339 .await?;
340
341 let element = self.maybe_fetch_type_info_by_oid(element_oid, true).await?;
342
343 Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
344 kind: PgTypeKind::Range(element),
345 name: name.into(),
346 oid,
347 }))))
348 }
349
350 pub(crate) async fn resolve_type_id(&mut self, ty: &PgType) -> Result<Oid, Error> {
351 if let Some(oid) = ty.try_oid() {
352 return Ok(oid);
353 }
354
355 match ty {
356 PgType::DeclareWithName(name) => self.fetch_type_id_by_name(name).await,
357 PgType::DeclareArrayOf(array) => self.fetch_array_type_id(array).await,
358 _ => unreachable!("(bug) OID should be resolvable for type {ty:?}"),
360 }
361 }
362
363 pub(crate) async fn fetch_type_id_by_name(&mut self, name: &str) -> Result<Oid, Error> {
364 if let Some(oid) = self.inner.cache_type_oid.get(name) {
365 return Ok(*oid);
366 }
367
368 let (oid,): (Oid,) = query_as("SELECT $1::regtype::oid")
370 .bind(name)
371 .fetch_optional(&mut *self)
372 .await?
373 .ok_or_else(|| Error::TypeNotFound {
374 type_name: name.into(),
375 })?;
376
377 self.inner
378 .cache_type_oid
379 .insert(name.to_string().into(), oid);
380 Ok(oid)
381 }
382
383 pub(crate) async fn fetch_array_type_id(&mut self, array: &PgArrayOf) -> Result<Oid, Error> {
384 if let Some(oid) = self
385 .inner
386 .cache_type_oid
387 .get(&array.elem_name)
388 .and_then(|elem_oid| self.inner.cache_elem_type_to_array.get(elem_oid))
389 {
390 return Ok(*oid);
391 }
392
393 let (elem_oid, array_oid): (Oid, Oid) =
395 query_as("SELECT oid, typarray FROM pg_catalog.pg_type WHERE oid = $1::regtype::oid")
396 .bind(&*array.elem_name)
397 .fetch_optional(&mut *self)
398 .await?
399 .ok_or_else(|| Error::TypeNotFound {
400 type_name: array.name.to_string(),
401 })?;
402
403 self.inner
405 .cache_type_oid
406 .entry_ref(&array.elem_name)
407 .insert(elem_oid);
408 self.inner
409 .cache_elem_type_to_array
410 .insert(elem_oid, array_oid);
411
412 Ok(array_oid)
413 }
414
415 fn is_explain_available(&self) -> bool {
417 let parameter_statuses = &self.inner.stream.parameter_statuses;
418 let is_cockroachdb = parameter_statuses.contains_key("crdb_version");
419 let is_materialize = parameter_statuses.contains_key("mz_version");
420 let is_questdb = parameter_statuses.contains_key("questdb_version");
421 !is_cockroachdb && !is_materialize && !is_questdb
422 }
423
424 pub(crate) async fn get_nullable_for_columns(
425 &mut self,
426 stmt_id: StatementId,
427 meta: &PgStatementMetadata,
428 ) -> Result<Vec<Option<bool>>, Error> {
429 if meta.columns.is_empty() {
430 return Ok(vec![]);
431 }
432
433 if meta.columns.len() * 3 > 65535 {
434 tracing::debug!(
435 ?stmt_id,
436 num_columns = meta.columns.len(),
437 "number of columns in query is too large to pull nullability for"
438 );
439 }
440
441 let mut nullable_query = QueryBuilder::new("SELECT NOT attnotnull FROM ( ");
450 let mut separated = nullable_query.separated("UNION ALL ");
451
452 let mut column_iter = meta.columns.iter().zip(0i32..);
453 if let Some((column, i)) = column_iter.next() {
454 separated.push("( SELECT ");
455 separated
456 .push_bind_unseparated(i)
457 .push_unseparated("::int4 AS idx, ");
458 separated
459 .push_bind_unseparated(column.relation_id)
460 .push_unseparated("::int4 AS table_id, ");
461 separated
462 .push_bind_unseparated(column.relation_attribute_no)
463 .push_unseparated("::int2 AS col_idx ) ");
464 }
465
466 for (column, i) in column_iter {
467 separated.push("( SELECT ");
468 separated
469 .push_bind_unseparated(i)
470 .push_unseparated("::int4, ");
471 separated
472 .push_bind_unseparated(column.relation_id)
473 .push_unseparated("::int4, ");
474 separated
475 .push_bind_unseparated(column.relation_attribute_no)
476 .push_unseparated("::int2 ) ");
477 }
478
479 nullable_query.push(
480 ") AS col LEFT JOIN pg_catalog.pg_attribute \
481 ON table_id IS NOT NULL \
482 AND attrelid = table_id \
483 AND attnum = col_idx \
484 ORDER BY idx",
485 );
486
487 let mut nullables: Vec<Option<bool>> = nullable_query
488 .build_query_scalar()
489 .fetch_all(&mut *self)
490 .await
491 .map_err(|e| {
492 err_protocol!(
493 "error from nullables query: {e}; query: {:?}",
494 nullable_query.sql()
495 )
496 })?;
497
498 if self.is_explain_available() {
500 let nullable_patch = self
502 .nullables_from_explain(stmt_id, meta.parameters.len())
503 .await?;
504
505 for (nullable, patch) in nullables.iter_mut().zip(nullable_patch) {
506 *nullable = patch.or(*nullable);
507 }
508 }
509
510 Ok(nullables)
511 }
512
513 async fn nullables_from_explain(
518 &mut self,
519 stmt_id: StatementId,
520 params_len: usize,
521 ) -> Result<Vec<Option<bool>>, Error> {
522 let stmt_id_display = stmt_id
523 .display()
524 .ok_or_else(|| err_protocol!("cannot EXPLAIN unnamed statement: {stmt_id:?}"))?;
525
526 let mut explain = format!("EXPLAIN (VERBOSE, FORMAT JSON) EXECUTE {stmt_id_display}");
527 let mut comma = false;
528
529 if params_len > 0 {
530 explain += "(";
531
532 for _ in 0..params_len {
534 if comma {
535 explain += ", ";
536 }
537
538 explain += "NULL";
539 comma = true;
540 }
541
542 explain += ")";
543 }
544
545 let (Json(explains),): (Json<SmallVec<[Explain; 1]>>,) =
546 query_as(&explain).fetch_one(self).await?;
547
548 let mut nullables = Vec::new();
549
550 if let Some(Explain::Plan {
551 plan:
552 plan @ Plan {
553 output: Some(ref outputs),
554 ..
555 },
556 }) = explains.first()
557 {
558 nullables.resize(outputs.len(), None);
559 visit_plan(plan, outputs, &mut nullables);
560 }
561
562 Ok(nullables)
563 }
564}
565
566fn visit_plan(plan: &Plan, outputs: &[String], nullables: &mut Vec<Option<bool>>) {
567 if let Some(plan_outputs) = &plan.output {
568 if plan.join_type.as_deref() == Some("Full")
571 || plan.parent_relation.as_deref() == Some("Inner")
572 {
573 for output in plan_outputs {
574 if let Some(i) = outputs.iter().position(|o| o == output) {
575 nullables[i] = Some(true);
577 }
578 }
579 }
580 }
581
582 if let Some(plans) = &plan.plans {
583 if let Some("Left") | Some("Right") = plan.join_type.as_deref() {
584 for plan in plans {
585 visit_plan(plan, outputs, nullables);
586 }
587 }
588 }
589}
590
/// One element of the JSON array returned by `EXPLAIN (..., FORMAT JSON)`.
///
/// Deserialized `untagged`: serde tries the variants in declaration order. An
/// element that is an object with a `"Plan"` key becomes `Explain::Plan`.
/// Anything else — for example the bare string `"Utility Statement"` used in
/// the `explain_parsing` test — falls through to `Other`. `Other` must stay
/// last because `IgnoredAny` accepts any JSON value.
#[derive(serde::Deserialize, Debug)]
#[serde(untagged)]
enum Explain {
    /// A regular plan; extra keys beside `"Plan"` (e.g. `"Query Identifier"`)
    /// are ignored.
    Plan {
        #[serde(rename = "Plan")]
        plan: Plan,
    },

    /// Any other entry; its contents are discarded.
    Other(serde::de::IgnoredAny),
}

/// The fields of a plan node that nullability inference reads.
///
/// Other keys in the node object (e.g. `"Node Type"`, costs) are ignored during
/// deserialization, and every field here is optional.
#[derive(serde::Deserialize, Debug)]
struct Plan {
    /// `"Join Type"` of a join node, compared against `"Left"`, `"Right"`, `"Full"`.
    #[serde(rename = "Join Type")]
    join_type: Option<String>,
    /// `"Parent Relationship"`: this node's role under its parent, e.g. `"Inner"`.
    #[serde(rename = "Parent Relationship")]
    parent_relation: Option<String>,
    /// `"Output"`: the expressions this node emits, as text.
    #[serde(rename = "Output")]
    output: Option<Vec<String>>,
    /// `"Plans"`: this node's child plan nodes.
    #[serde(rename = "Plans")]
    plans: Option<Vec<Plan>>,
}
625
/// Checks that `EXPLAIN (FORMAT JSON)` output deserializes into the expected
/// `Explain` variant: a plain plan, a plan element with an extra sibling key,
/// and a bare utility-statement string.
#[test]
fn explain_parsing() {
    /// Parses a one-element `EXPLAIN` JSON array, panicking on malformed input.
    fn parse(json: &str) -> [Explain; 1] {
        serde_json::from_str::<[Explain; 1]>(json).unwrap()
    }

    // A single `Result` node.
    let normal_plan = r#"[
  {
    "Plan": {
      "Node Type": "Result",
      "Parallel Aware": false,
      "Async Capable": false,
      "Startup Cost": 0.00,
      "Total Cost": 0.01,
      "Plan Rows": 1,
      "Plan Width": 4,
      "Output": ["1"]
    }
  }
]"#;

    // The same plan, with an extra key beside "Plan" in the array element.
    let extra_field = r#"[
  {
    "Plan": {
      "Node Type": "Result",
      "Parallel Aware": false,
      "Async Capable": false,
      "Startup Cost": 0.00,
      "Total Cost": 0.01,
      "Plan Rows": 1,
      "Plan Width": 4,
      "Output": ["1"]
    },
    "Query Identifier": 1147616880456321454
  }
]"#;

    // A utility statement has no plan tree, only a string.
    let utility_statement = r#"["Utility Statement"]"#;

    let normal_plan_parsed = parse(normal_plan);
    let extra_field_parsed = parse(extra_field);
    let utility_statement_parsed = parse(utility_statement);

    // (input, parsed, whether a `Plan` variant is expected)
    let cases = [
        (normal_plan, normal_plan_parsed, true),
        (extra_field, extra_field_parsed, true),
        (utility_statement, utility_statement_parsed, false),
    ];

    for (input, parsed, expect_plan) in cases.iter() {
        let as_expected = if *expect_plan {
            matches!(parsed, [Explain::Plan { plan: Plan { .. } }])
        } else {
            matches!(parsed, [Explain::Other(_)])
        };

        assert!(as_expected, "unexpected parse from {input:?}: {parsed:?}");
    }
}
681}