1use crate::error::Error;
2use crate::ext::ustr::UStr;
3use crate::postgres::message::{ParameterDescription, RowDescription};
4use crate::postgres::statement::PgStatementMetadata;
5use crate::postgres::type_info::{PgCustomType, PgType, PgTypeKind};
6use crate::postgres::types::Oid;
7use crate::postgres::{PgArguments, PgColumn, PgConnection, PgTypeInfo};
8use crate::query_as::query_as;
9use crate::query_scalar::{query_scalar, query_scalar_with};
10use crate::types::Json;
11use crate::HashMap;
12use futures_core::future::BoxFuture;
13use std::fmt::Write;
14use std::sync::Arc;
15
/// Decoded value of the single-character `pg_type.typtype` catalog column.
///
/// Codes not listed here are rejected by the [`TryFrom<u8>`] conversion below.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum TypType {
    /// `b`
    Base,
    /// `c`
    Composite,
    /// `d`
    Domain,
    /// `e`
    Enum,
    /// `p`
    Pseudo,
    /// `r`
    Range,
}

impl TryFrom<u8> for TypType {
    type Error = ();

    /// Maps a raw `typtype` byte to its variant; any unknown byte yields `Err(())`.
    fn try_from(code: u8) -> Result<Self, Self::Error> {
        match code {
            b'b' => Ok(TypType::Base),
            b'c' => Ok(TypType::Composite),
            b'd' => Ok(TypType::Domain),
            b'e' => Ok(TypType::Enum),
            b'p' => Ok(TypType::Pseudo),
            b'r' => Ok(TypType::Range),
            _ => Err(()),
        }
    }
}
45
/// Decoded value of the single-character `pg_type.typcategory` catalog column.
///
/// Codes not listed here are rejected by the [`TryFrom<u8>`] conversion below.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum TypCategory {
    /// `A`
    Array,
    /// `B`
    Boolean,
    /// `C`
    Composite,
    /// `D`
    DateTime,
    /// `E`
    Enum,
    /// `G`
    Geometric,
    /// `I`
    Network,
    /// `N`
    Numeric,
    /// `P`
    Pseudo,
    /// `R`
    Range,
    /// `S`
    String,
    /// `T`
    Timespan,
    /// `U`
    User,
    /// `V`
    BitString,
    /// `X`
    Unknown,
}

impl TryFrom<u8> for TypCategory {
    type Error = ();

    /// Maps a raw `typcategory` byte to its variant; any unknown byte yields `Err(())`.
    fn try_from(code: u8) -> Result<Self, Self::Error> {
        let category = match code {
            b'A' => TypCategory::Array,
            b'B' => TypCategory::Boolean,
            b'C' => TypCategory::Composite,
            b'D' => TypCategory::DateTime,
            b'E' => TypCategory::Enum,
            b'G' => TypCategory::Geometric,
            b'I' => TypCategory::Network,
            b'N' => TypCategory::Numeric,
            b'P' => TypCategory::Pseudo,
            b'R' => TypCategory::Range,
            b'S' => TypCategory::String,
            b'T' => TypCategory::Timespan,
            b'U' => TypCategory::User,
            b'V' => TypCategory::BitString,
            b'X' => TypCategory::Unknown,
            _ => return Err(()),
        };

        Ok(category)
    }
}
93
94impl PgConnection {
95 pub(super) async fn handle_row_description(
96 &mut self,
97 desc: Option<RowDescription>,
98 should_fetch: bool,
99 ) -> Result<(Vec<PgColumn>, HashMap<UStr, usize>), Error> {
100 let mut columns = Vec::new();
101 let mut column_names = HashMap::new();
102
103 let desc = if let Some(desc) = desc {
104 desc
105 } else {
106 return Ok((columns, column_names));
108 };
109
110 columns.reserve(desc.fields.len());
111 column_names.reserve(desc.fields.len());
112
113 for (index, field) in desc.fields.into_iter().enumerate() {
114 let name = UStr::from(field.name);
115
116 let type_info = self
117 .maybe_fetch_type_info_by_oid(field.data_type_id, should_fetch)
118 .await?;
119
120 let column = PgColumn {
121 ordinal: index,
122 name: name.clone(),
123 type_info,
124 relation_id: field.relation_id,
125 relation_attribute_no: field.relation_attribute_no,
126 };
127
128 columns.push(column);
129 column_names.insert(name, index);
130 }
131
132 Ok((columns, column_names))
133 }
134
135 pub(super) async fn handle_parameter_description(
136 &mut self,
137 desc: ParameterDescription,
138 ) -> Result<Vec<PgTypeInfo>, Error> {
139 let mut params = Vec::with_capacity(desc.types.len());
140
141 for ty in desc.types {
142 params.push(self.maybe_fetch_type_info_by_oid(ty, true).await?);
143 }
144
145 Ok(params)
146 }
147
148 async fn maybe_fetch_type_info_by_oid(
149 &mut self,
150 oid: Oid,
151 should_fetch: bool,
152 ) -> Result<PgTypeInfo, Error> {
153 if let Some(info) = PgTypeInfo::try_from_oid(oid) {
156 return Ok(info);
157 }
158
159 if let Some(info) = self.cache_type_info.get(&oid) {
161 return Ok(info.clone());
162 }
163
164 if should_fetch {
166 let info = self.fetch_type_by_oid(oid).await?;
167
168 self.cache_type_info.insert(oid, info.clone());
171 self.cache_type_oid
172 .insert(info.0.name().to_string().into(), oid);
173
174 Ok(info)
175 } else {
176 Ok(PgTypeInfo(PgType::DeclareWithOid(oid)))
183 }
184 }
185
186 fn fetch_type_by_oid(&mut self, oid: Oid) -> BoxFuture<'_, Result<PgTypeInfo, Error>> {
187 Box::pin(async move {
188 let (name, typ_type, category, relation_id, element, base_type): (String, i8, i8, Oid, Oid, Oid) = query_as(
189 "SELECT typname, typtype, typcategory, typrelid, typelem, typbasetype FROM pg_catalog.pg_type WHERE oid = $1",
190 )
191 .bind(oid)
192 .fetch_one(&mut *self)
193 .await?;
194
195 let typ_type = TypType::try_from(typ_type as u8);
196 let category = TypCategory::try_from(category as u8);
197
198 match (typ_type, category) {
199 (Ok(TypType::Domain), _) => self.fetch_domain_by_oid(oid, base_type, name).await,
200
201 (Ok(TypType::Base), Ok(TypCategory::Array)) => {
202 Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
203 kind: PgTypeKind::Array(
204 self.maybe_fetch_type_info_by_oid(element, true).await?,
205 ),
206 name: name.into(),
207 oid,
208 }))))
209 }
210
211 (Ok(TypType::Pseudo), Ok(TypCategory::Pseudo)) => {
212 Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
213 kind: PgTypeKind::Pseudo,
214 name: name.into(),
215 oid,
216 }))))
217 }
218
219 (Ok(TypType::Range), Ok(TypCategory::Range)) => {
220 self.fetch_range_by_oid(oid, name).await
221 }
222
223 (Ok(TypType::Enum), Ok(TypCategory::Enum)) => {
224 self.fetch_enum_by_oid(oid, name).await
225 }
226
227 (Ok(TypType::Composite), Ok(TypCategory::Composite)) => {
228 self.fetch_composite_by_oid(oid, relation_id, name).await
229 }
230
231 _ => Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
232 kind: PgTypeKind::Simple,
233 name: name.into(),
234 oid,
235 })))),
236 }
237 })
238 }
239
240 async fn fetch_enum_by_oid(&mut self, oid: Oid, name: String) -> Result<PgTypeInfo, Error> {
241 let variants: Vec<String> = query_scalar(
242 r#"
243SELECT enumlabel
244FROM pg_catalog.pg_enum
245WHERE enumtypid = $1
246ORDER BY enumsortorder
247 "#,
248 )
249 .bind(oid)
250 .fetch_all(self)
251 .await?;
252
253 Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
254 oid,
255 name: name.into(),
256 kind: PgTypeKind::Enum(Arc::from(variants)),
257 }))))
258 }
259
260 fn fetch_composite_by_oid(
261 &mut self,
262 oid: Oid,
263 relation_id: Oid,
264 name: String,
265 ) -> BoxFuture<'_, Result<PgTypeInfo, Error>> {
266 Box::pin(async move {
267 let raw_fields: Vec<(String, Oid)> = query_as(
268 r#"
269SELECT attname, atttypid
270FROM pg_catalog.pg_attribute
271WHERE attrelid = $1
272AND NOT attisdropped
273AND attnum > 0
274ORDER BY attnum
275 "#,
276 )
277 .bind(relation_id)
278 .fetch_all(&mut *self)
279 .await?;
280
281 let mut fields = Vec::new();
282
283 for (field_name, field_oid) in raw_fields.into_iter() {
284 let field_type = self.maybe_fetch_type_info_by_oid(field_oid, true).await?;
285
286 fields.push((field_name, field_type));
287 }
288
289 Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
290 oid,
291 name: name.into(),
292 kind: PgTypeKind::Composite(Arc::from(fields)),
293 }))))
294 })
295 }
296
297 fn fetch_domain_by_oid(
298 &mut self,
299 oid: Oid,
300 base_type: Oid,
301 name: String,
302 ) -> BoxFuture<'_, Result<PgTypeInfo, Error>> {
303 Box::pin(async move {
304 let base_type = self.maybe_fetch_type_info_by_oid(base_type, true).await?;
305
306 Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
307 oid,
308 name: name.into(),
309 kind: PgTypeKind::Domain(base_type),
310 }))))
311 })
312 }
313
314 fn fetch_range_by_oid(
315 &mut self,
316 oid: Oid,
317 name: String,
318 ) -> BoxFuture<'_, Result<PgTypeInfo, Error>> {
319 Box::pin(async move {
320 let element_oid: Oid = query_scalar(
321 r#"
322SELECT rngsubtype
323FROM pg_catalog.pg_range
324WHERE rngtypid = $1
325 "#,
326 )
327 .bind(oid)
328 .fetch_one(&mut *self)
329 .await?;
330
331 let element = self.maybe_fetch_type_info_by_oid(element_oid, true).await?;
332
333 Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
334 kind: PgTypeKind::Range(element),
335 name: name.into(),
336 oid,
337 }))))
338 })
339 }
340
341 pub(crate) async fn fetch_type_id_by_name(&mut self, name: &str) -> Result<Oid, Error> {
342 if let Some(oid) = self.cache_type_oid.get(name) {
343 return Ok(*oid);
344 }
345
346 let (oid,): (Oid,) = query_as(
348 "
349SELECT oid FROM pg_catalog.pg_type WHERE typname ILIKE $1
350 ",
351 )
352 .bind(name)
353 .fetch_optional(&mut *self)
354 .await?
355 .ok_or_else(|| Error::TypeNotFound {
356 type_name: String::from(name),
357 })?;
358
359 self.cache_type_oid.insert(name.to_string().into(), oid);
360 Ok(oid)
361 }
362
363 pub(crate) async fn get_nullable_for_columns(
364 &mut self,
365 stmt_id: Oid,
366 meta: &PgStatementMetadata,
367 ) -> Result<Vec<Option<bool>>, Error> {
368 if meta.columns.is_empty() {
369 return Ok(vec![]);
370 }
371
372 let mut nullable_query = String::from("SELECT NOT pg_attribute.attnotnull FROM (VALUES ");
373 let mut args = PgArguments::default();
374
375 for (i, (column, bind)) in meta.columns.iter().zip((1..).step_by(3)).enumerate() {
376 if !args.buffer.is_empty() {
377 nullable_query += ", ";
378 }
379
380 let _ = write!(
381 nullable_query,
382 "(${}::int4, ${}::int4, ${}::int2)",
383 bind,
384 bind + 1,
385 bind + 2
386 );
387
388 args.add(i as i32);
389 args.add(column.relation_id);
390 args.add(column.relation_attribute_no);
391 }
392
393 nullable_query.push_str(
394 ") as col(idx, table_id, col_idx) \
395 LEFT JOIN pg_catalog.pg_attribute \
396 ON table_id IS NOT NULL \
397 AND attrelid = table_id \
398 AND attnum = col_idx \
399 ORDER BY col.idx",
400 );
401
402 let mut nullables = query_scalar_with::<_, Option<bool>, _>(&nullable_query, args)
403 .fetch_all(&mut *self)
404 .await?;
405
406 if !self.stream.parameter_statuses.contains_key("crdb_version") {
408 let nullable_patch = self
410 .nullables_from_explain(stmt_id, meta.parameters.len())
411 .await?;
412
413 for (nullable, patch) in nullables.iter_mut().zip(nullable_patch) {
414 *nullable = patch.or(*nullable);
415 }
416 }
417
418 Ok(nullables)
419 }
420
421 async fn nullables_from_explain(
426 &mut self,
427 stmt_id: Oid,
428 params_len: usize,
429 ) -> Result<Vec<Option<bool>>, Error> {
430 let mut explain = format!(
431 "EXPLAIN (VERBOSE, FORMAT JSON) EXECUTE sqlx_s_{}",
432 stmt_id.0
433 );
434 let mut comma = false;
435
436 if params_len > 0 {
437 explain += "(";
438
439 for _ in 0..params_len {
441 if comma {
442 explain += ", ";
443 }
444
445 explain += "NULL";
446 comma = true;
447 }
448
449 explain += ")";
450 }
451
452 let (Json([explain]),): (Json<[Explain; 1]>,) = query_as(&explain).fetch_one(self).await?;
453
454 let mut nullables = Vec::new();
455
456 if let Some(outputs) = &explain.plan.output {
457 nullables.resize(outputs.len(), None);
458 visit_plan(&explain.plan, outputs, &mut nullables);
459 }
460
461 Ok(nullables)
462 }
463}
464
465fn visit_plan(plan: &Plan, outputs: &[String], nullables: &mut Vec<Option<bool>>) {
466 if let Some(plan_outputs) = &plan.output {
467 if plan.join_type.as_deref() == Some("Full")
470 || plan.parent_relation.as_deref() == Some("Inner")
471 {
472 for output in plan_outputs {
473 if let Some(i) = outputs.iter().position(|o| o == output) {
474 nullables[i] = Some(true);
476 }
477 }
478 }
479 }
480
481 if let Some(plans) = &plan.plans {
482 if let Some("Left") | Some("Right") = plan.join_type.as_deref() {
483 for plan in plans {
484 visit_plan(plan, outputs, nullables);
485 }
486 }
487 }
488}
489
490#[derive(serde::Deserialize)]
491struct Explain {
492 #[serde(rename = "Plan")]
493 plan: Plan,
494}
495
496#[derive(serde::Deserialize)]
497struct Plan {
498 #[serde(rename = "Join Type")]
499 join_type: Option<String>,
500 #[serde(rename = "Parent Relationship")]
501 parent_relation: Option<String>,
502 #[serde(rename = "Output")]
503 output: Option<Vec<String>>,
504 #[serde(rename = "Plans")]
505 plans: Option<Vec<Plan>>,
506}