use crate::error::Error;
use crate::ext::ustr::UStr;
use crate::postgres::message::{ParameterDescription, RowDescription};
use crate::postgres::statement::PgStatementMetadata;
use crate::postgres::type_info::{PgCustomType, PgType, PgTypeKind};
use crate::postgres::{PgArguments, PgColumn, PgConnection, PgTypeInfo};
use crate::query_as::query_as;
use crate::query_scalar::{query_scalar, query_scalar_with};
use crate::types::Json;
use crate::HashMap;
use futures_core::future::BoxFuture;
use std::fmt::Write;
use std::sync::Arc;
impl PgConnection {
    /// Builds the column list and a name-to-ordinal map from a `RowDescription` message.
    ///
    /// `None` (no row description) yields two empty collections. Each column's type is resolved
    /// with `maybe_fetch_type_info_by_oid`; when `should_fetch` is `false`, OIDs that are neither
    /// built-in nor cached become `PgType::DeclareWithOid` instead of triggering a catalog query.
    ///
    /// If several columns share a name, the map holds the ordinal of the last one.
    pub(super) async fn handle_row_description(
        &mut self,
        desc: Option<RowDescription>,
        should_fetch: bool,
    ) -> Result<(Vec<PgColumn>, HashMap<UStr, usize>), Error> {
        let mut columns = Vec::new();
        let mut column_names = HashMap::new();

        let desc = match desc {
            Some(desc) => desc,
            None => return Ok((columns, column_names)),
        };

        columns.reserve(desc.fields.len());
        column_names.reserve(desc.fields.len());

        for (index, field) in desc.fields.into_iter().enumerate() {
            let name = UStr::from(field.name);
            let type_info = self
                .maybe_fetch_type_info_by_oid(field.data_type_id, should_fetch)
                .await?;

            columns.push(PgColumn {
                ordinal: index,
                name: name.clone(),
                type_info,
                relation_id: field.relation_id,
                relation_attribute_no: field.relation_attribute_no,
            });
            column_names.insert(name, index);
        }

        Ok((columns, column_names))
    }

    /// Resolves the type of every parameter listed in a `ParameterDescription`, querying the
    /// catalog for any type that is neither built-in nor cached.
    pub(super) async fn handle_parameter_description(
        &mut self,
        desc: ParameterDescription,
    ) -> Result<Vec<PgTypeInfo>, Error> {
        let mut params = Vec::with_capacity(desc.types.len());

        for ty in desc.types {
            params.push(self.maybe_fetch_type_info_by_oid(ty, true).await?);
        }

        Ok(params)
    }

    /// Resolves `oid` to a `PgTypeInfo`.
    ///
    /// Order of resolution: built-in types (`PgTypeInfo::try_from_oid`), then `cache_type_info`,
    /// then — only if `should_fetch` — a catalog query whose result is stored in both
    /// `cache_type_info` (by OID) and `cache_type_oid` (by name). Without `should_fetch`, an
    /// unknown OID is returned as `PgType::DeclareWithOid(oid)` and nothing is cached.
    async fn maybe_fetch_type_info_by_oid(
        &mut self,
        oid: u32,
        should_fetch: bool,
    ) -> Result<PgTypeInfo, Error> {
        if let Some(info) = PgTypeInfo::try_from_oid(oid) {
            return Ok(info);
        }

        if let Some(info) = self.cache_type_info.get(&oid) {
            return Ok(info.clone());
        }

        if should_fetch {
            let info = self.fetch_type_by_oid(oid).await?;

            self.cache_type_info.insert(oid, info.clone());
            self.cache_type_oid
                .insert(info.0.name().to_string().into(), oid);

            Ok(info)
        } else {
            Ok(PgTypeInfo(PgType::DeclareWithOid(oid)))
        }
    }

    /// Queries `pg_catalog.pg_type` for `oid` and builds a custom `PgTypeInfo` from its
    /// `typcategory`: `A` (array), `P` (pseudo), `R` (range), `E` (enum) and `C` (composite) get
    /// dedicated kinds; every other category becomes `PgTypeKind::Simple`.
    ///
    /// Returns a boxed future because it is reached recursively: array elements, range subtypes
    /// and composite fields are resolved through `maybe_fetch_type_info_by_oid`, which calls
    /// back into this function.
    fn fetch_type_by_oid(&mut self, oid: u32) -> BoxFuture<'_, Result<PgTypeInfo, Error>> {
        Box::pin(async move {
            let (name, category, relation_id, element): (String, i8, u32, u32) = query_as(
                "SELECT typname, typcategory, typrelid, typelem FROM pg_catalog.pg_type WHERE oid = $1",
            )
            .bind(oid)
            .fetch_one(&mut *self)
            .await?;

            // `typcategory` is decoded as `i8`; compare it as an ASCII byte.
            match category as u8 {
                b'A' => {
                    // Go through the cache-aware path so a built-in element type is recognised
                    // as such and a custom one is fetched (and cached) only once.
                    let element = self.maybe_fetch_type_info_by_oid(element, true).await?;

                    Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
                        kind: PgTypeKind::Array(element),
                        name: name.into(),
                        oid,
                    }))))
                }

                b'P' => Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
                    kind: PgTypeKind::Pseudo,
                    name: name.into(),
                    oid,
                })))),

                b'R' => self.fetch_range_by_oid(oid, name).await,

                b'E' => self.fetch_enum_by_oid(oid, name).await,

                b'C' => self.fetch_composite_by_oid(oid, relation_id, name).await,

                _ => Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
                    kind: PgTypeKind::Simple,
                    name: name.into(),
                    oid,
                })))),
            }
        })
    }

    /// Builds an enum type whose variants are the `pg_enum` labels of `oid`, ordered by
    /// `enumsortorder`.
    async fn fetch_enum_by_oid(&mut self, oid: u32, name: String) -> Result<PgTypeInfo, Error> {
        let variants: Vec<String> = query_scalar(
            r#"
SELECT enumlabel
FROM pg_catalog.pg_enum
WHERE enumtypid = $1
ORDER BY enumsortorder
"#,
        )
        .bind(oid)
        .fetch_all(self)
        .await?;

        Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
            oid,
            name: name.into(),
            kind: PgTypeKind::Enum(Arc::from(variants)),
        }))))
    }

    /// Builds a composite type from the attributes of relation `relation_id`: dropped columns and
    /// system columns (`attnum <= 0`) are skipped, the rest are kept in `attnum` order, and each
    /// field type is resolved (fetching it if needed).
    ///
    /// Boxed because resolving a field type can recurse back into `fetch_type_by_oid`.
    fn fetch_composite_by_oid(
        &mut self,
        oid: u32,
        relation_id: u32,
        name: String,
    ) -> BoxFuture<'_, Result<PgTypeInfo, Error>> {
        Box::pin(async move {
            let raw_fields: Vec<(String, u32)> = query_as(
                r#"
SELECT attname, atttypid
FROM pg_catalog.pg_attribute
WHERE attrelid = $1
AND NOT attisdropped
AND attnum > 0
ORDER BY attnum
"#,
            )
            .bind(relation_id)
            .fetch_all(&mut *self)
            .await?;

            let mut fields = Vec::with_capacity(raw_fields.len());

            for (field_name, field_oid) in raw_fields {
                let field_type = self.maybe_fetch_type_info_by_oid(field_oid, true).await?;
                fields.push((field_name, field_type));
            }

            Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
                oid,
                name: name.into(),
                kind: PgTypeKind::Composite(Arc::from(fields)),
            }))))
        })
    }

    /// Builds a range type whose element is the range's `rngsubtype` (fetched if needed).
    ///
    /// Boxed because resolving the subtype can recurse back into `fetch_type_by_oid`.
    fn fetch_range_by_oid(
        &mut self,
        oid: u32,
        name: String,
    ) -> BoxFuture<'_, Result<PgTypeInfo, Error>> {
        Box::pin(async move {
            let element_oid: u32 = query_scalar(
                r#"
SELECT rngsubtype
FROM pg_catalog.pg_range
WHERE rngtypid = $1
"#,
            )
            .bind(oid)
            .fetch_one(&mut *self)
            .await?;

            let element = self.maybe_fetch_type_info_by_oid(element_oid, true).await?;

            Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
                kind: PgTypeKind::Range(element),
                name: name.into(),
                oid,
            }))))
        })
    }

    /// Resolves a type name to its OID, consulting `cache_type_oid` first and caching the result
    /// under `name`.
    ///
    /// The catalog lookup compares `typname` case-insensitively (`ILIKE`). `\`, `%` and `_` in
    /// `name` are escaped so they match literally instead of acting as LIKE wildcards (otherwise
    /// `my_type` would also match `myxtype`). The query has no `ORDER BY`, so if several rows
    /// match (e.g. equally named types in different schemas) which OID is returned is not
    /// defined.
    ///
    /// # Errors
    /// `Error::TypeNotFound` if no `pg_type` row matches; query errors are propagated.
    pub(crate) async fn fetch_type_id_by_name(&mut self, name: &str) -> Result<u32, Error> {
        if let Some(oid) = self.cache_type_oid.get(name) {
            return Ok(*oid);
        }

        // Backslash is PostgreSQL's default LIKE/ILIKE escape character.
        let mut pattern = String::with_capacity(name.len());
        for ch in name.chars() {
            if matches!(ch, '\\' | '%' | '_') {
                pattern.push('\\');
            }
            pattern.push(ch);
        }

        let (oid,): (u32,) = query_as(
            "
SELECT oid FROM pg_catalog.pg_type WHERE typname ILIKE $1
",
        )
        .bind(pattern)
        .fetch_optional(&mut *self)
        .await?
        .ok_or_else(|| Error::TypeNotFound {
            type_name: String::from(name),
        })?;

        self.cache_type_oid.insert(name.to_string().into(), oid);
        Ok(oid)
    }

    /// Infers, for each column of the prepared statement `stmt_id`, whether it may be NULL.
    ///
    /// First pass: columns that come from a table take `NOT pg_attribute.attnotnull`; all other
    /// columns get `None` (unknown). Second pass: any `Some` produced by
    /// `nullables_from_explain` for the same position overrides the first answer.
    pub(crate) async fn get_nullable_for_columns(
        &mut self,
        stmt_id: u32,
        meta: &PgStatementMetadata,
    ) -> Result<Vec<Option<bool>>, Error> {
        if meta.columns.is_empty() {
            return Ok(vec![]);
        }

        // One VALUES row per column: (column index, source table OID, attribute number),
        // bound as $1..$3, $4..$6, ...
        let mut nullable_query = String::from("SELECT NOT pg_attribute.attnotnull FROM (VALUES ");
        let mut args = PgArguments::default();

        for (i, (column, bind)) in meta.columns.iter().zip((1..).step_by(3)).enumerate() {
            if i > 0 {
                nullable_query += ", ";
            }

            // Writing to a `String` cannot fail.
            let _ = write!(
                nullable_query,
                "(${}::int4, ${}::int4, ${}::int2)",
                bind,
                bind + 1,
                bind + 2
            );

            args.add(i as i32);
            args.add(column.relation_id);
            args.add(column.relation_attribute_no);
        }

        // LEFT JOIN keeps columns without a source table; their attnotnull is NULL, so they
        // come back as `None`. Ordering by the index keeps results aligned with `meta.columns`.
        nullable_query.push_str(
            ") as col(idx, table_id, col_idx) \
            LEFT JOIN pg_catalog.pg_attribute \
            ON table_id IS NOT NULL \
            AND attrelid = table_id \
            AND attnum = col_idx \
            ORDER BY col.idx",
        );

        let mut nullables = query_scalar_with::<_, Option<bool>, _>(&nullable_query, args)
            .fetch_all(&mut *self)
            .await?;

        let nullable_patch = self
            .nullables_from_explain(stmt_id, meta.parameters.len())
            .await?;

        for (nullable, patch) in nullables.iter_mut().zip(nullable_patch) {
            *nullable = patch.or(*nullable);
        }

        Ok(nullables)
    }

    /// Runs `EXPLAIN (VERBOSE, FORMAT JSON)` on the prepared statement `sqlx_s_{stmt_id}` and
    /// returns one entry per output of the top-level plan: `Some(true)` where `visit_plan` found
    /// the output to be nullable, `None` otherwise. Empty if the top-level plan lists no outputs.
    async fn nullables_from_explain(
        &mut self,
        stmt_id: u32,
        params_len: usize,
    ) -> Result<Vec<Option<bool>>, Error> {
        let mut explain = format!("EXPLAIN (VERBOSE, FORMAT JSON) EXECUTE sqlx_s_{}", stmt_id);

        // `NULL` is supplied for every parameter; only the plan shape is needed, not values.
        if params_len > 0 {
            explain.push('(');
            for i in 0..params_len {
                if i > 0 {
                    explain.push_str(", ");
                }
                explain.push_str("NULL");
            }
            explain.push(')');
        }

        let (Json([explain]),): (Json<[Explain; 1]>,) = query_as(&explain).fetch_one(self).await?;

        let mut nullables = Vec::new();

        if let Some(outputs) = &explain.plan.output {
            nullables.resize(outputs.len(), None);
            visit_plan(&explain.plan, outputs, &mut nullables);
        }

        Ok(nullables)
    }
}
fn visit_plan(plan: &Plan, outputs: &[String], nullables: &mut Vec<Option<bool>>) {
if let Some(plan_outputs) = &plan.output {
if let Some("Full") | Some("Inner") = plan
.join_type
.as_deref()
.or(plan.parent_relation.as_deref())
{
for output in plan_outputs {
if let Some(i) = outputs.iter().position(|o| o == output) {
nullables[i] = Some(true);
}
}
}
}
if let Some(plans) = &plan.plans {
if let Some("Left") | Some("Right") = plan.join_type.as_deref() {
for plan in plans {
visit_plan(plan, outputs, nullables);
}
}
}
}
/// One element of the JSON array returned by `EXPLAIN (VERBOSE, FORMAT JSON)`, as read by
/// `nullables_from_explain`.
///
/// Only the root plan node is deserialized; other keys of the object are ignored (the type
/// does not use `deny_unknown_fields`).
#[derive(serde::Deserialize)]
struct Explain {
/// Root node of the query plan (JSON key `"Plan"`).
#[serde(rename = "Plan")]
plan: Plan,
}
/// One node of an `EXPLAIN (FORMAT JSON)` plan tree, limited to the keys `visit_plan` uses to
/// infer output nullability.
///
/// Every field is an `Option` because a key may be absent from a given node; a missing key
/// deserializes to `None`, and keys not listed here are ignored.
#[derive(serde::Deserialize)]
struct Plan {
/// `"Join Type"` of the node, e.g. `"Left"` or `"Full"`.
#[serde(rename = "Join Type")]
join_type: Option<String>,
/// `"Parent Relationship"`: how this node feeds its parent, e.g. `"Inner"`.
#[serde(rename = "Parent Relationship")]
parent_relation: Option<String>,
/// `"Output"`: the expressions this node emits, as text; matched by string against the
/// top-level output list.
#[serde(rename = "Output")]
output: Option<Vec<String>>,
/// `"Plans"`: child plan nodes.
#[serde(rename = "Plans")]
plans: Option<Vec<Plan>>,
}