// sql_middleware/postgres/query.rs

use crate::middleware::{ResultSet, RowValues, SqlMiddlewareDbError};
use crate::types::{ConversionMode, ParamConverter};
use chrono::{DateTime, NaiveDateTime, Utc};
use serde_json::Value;
use tokio_postgres::{Client, Statement, Transaction, types::ToSql};

use super::params::Params as PgParams;
8
9/// Build a result set from a Postgres query execution
10///
11/// # Errors
12/// Returns errors from query execution or result processing.
13pub async fn build_result_set(
14    stmt: &Statement,
15    params: &[&(dyn ToSql + Sync)],
16    transaction: &Transaction<'_>,
17) -> Result<ResultSet, SqlMiddlewareDbError> {
18    // Execute the query
19    let rows = transaction.query(stmt, params).await?;
20
21    let column_names: Vec<String> = stmt
22        .columns()
23        .iter()
24        .map(|col| col.name().to_string())
25        .collect();
26
27    // Preallocate capacity if we can estimate the number of rows
28    let capacity = rows.len();
29    let mut result_set = ResultSet::with_capacity(capacity);
30    // Store column names once in the result set
31    let column_names_rc = std::sync::Arc::new(column_names);
32    result_set.set_column_names(column_names_rc);
33
34    for row in rows {
35        let mut row_values = Vec::new();
36
37        let col_count = result_set
38            .get_column_names()
39            .ok_or_else(|| {
40                SqlMiddlewareDbError::ExecutionError("No column names available".to_string())
41            })?
42            .len();
43
44        for i in 0..col_count {
45            let value = postgres_extract_value(&row, i)?;
46            row_values.push(value);
47        }
48
49        result_set.add_row_values(row_values);
50    }
51
52    Ok(result_set)
53}
54
55/// Extracts a `RowValues` from a `tokio_postgres` Row at the given index.
56///
57/// # Errors
58/// Returns `SqlMiddlewareDbError` if the column cannot be retrieved.
59pub fn postgres_extract_value(
60    row: &tokio_postgres::Row,
61    idx: usize,
62) -> Result<RowValues, SqlMiddlewareDbError> {
63    // Determine the type of the column and extract accordingly
64    let type_info = row.columns()[idx].type_();
65
66    // Match on the type based on PostgreSQL type OIDs or names
67    // For simplicity, we'll handle common types. You may need to expand this.
68    if type_info.name() == "int2" {
69        let val: Option<i16> = row.try_get(idx)?;
70        Ok(val.map_or(RowValues::Null, |v| RowValues::Int(i64::from(v))))
71    } else if type_info.name() == "int4" {
72        let val: Option<i32> = row.try_get(idx)?;
73        Ok(val.map_or(RowValues::Null, |v| RowValues::Int(i64::from(v))))
74    } else if type_info.name() == "int8" {
75        let val: Option<i64> = row.try_get(idx)?;
76        Ok(val.map_or(RowValues::Null, RowValues::Int))
77    } else if type_info.name() == "float4" || type_info.name() == "float8" {
78        let val: Option<f64> = row.try_get(idx)?;
79        Ok(val.map_or(RowValues::Null, RowValues::Float))
80    } else if type_info.name() == "bool" {
81        let val: Option<bool> = row.try_get(idx)?;
82        Ok(val.map_or(RowValues::Null, RowValues::Bool))
83    } else if type_info.name() == "timestamp" || type_info.name() == "timestamptz" {
84        let val: Option<NaiveDateTime> = row.try_get(idx)?;
85        Ok(val.map_or(RowValues::Null, RowValues::Timestamp))
86    } else if type_info.name() == "json" || type_info.name() == "jsonb" {
87        let val: Option<Value> = row.try_get(idx)?;
88        Ok(val.map_or(RowValues::Null, RowValues::JSON))
89    } else if type_info.name() == "bytea" {
90        let val: Option<Vec<u8>> = row.try_get(idx)?;
91        Ok(val.map_or(RowValues::Null, RowValues::Blob))
92    } else if type_info.name() == "text"
93        || type_info.name() == "varchar"
94        || type_info.name() == "char"
95    {
96        let val: Option<String> = row.try_get(idx)?;
97        Ok(val.map_or(RowValues::Null, RowValues::Text))
98    } else {
99        // For other types, attempt to get as string
100        let val: Option<String> = row.try_get(idx)?;
101        Ok(val.map_or(RowValues::Null, RowValues::Text))
102    }
103}
104
105/// Build a result set from raw Postgres rows (without a Transaction)
106///
107/// # Errors
108/// Returns errors from result processing.
109pub fn build_result_set_from_rows(
110    rows: &[tokio_postgres::Row],
111) -> Result<ResultSet, SqlMiddlewareDbError> {
112    let mut result_set = ResultSet::with_capacity(rows.len());
113    if let Some(row) = rows.first() {
114        let cols: Vec<String> = row.columns().iter().map(|c| c.name().to_string()).collect();
115        result_set.set_column_names(std::sync::Arc::new(cols));
116    }
117
118    for row in rows {
119        let col_count = row.columns().len();
120        let mut row_values = Vec::with_capacity(col_count);
121        for idx in 0..col_count {
122            row_values.push(postgres_extract_value(row, idx)?);
123        }
124        result_set.add_row_values(row_values);
125    }
126
127    Ok(result_set)
128}
129
130/// Execute a SELECT query on a client without managing transactions
131///
132/// # Errors
133/// Returns errors from parameter conversion or query execution.
134pub async fn execute_query_on_client(
135    client: &Client,
136    query: &str,
137    params: &[RowValues],
138) -> Result<ResultSet, SqlMiddlewareDbError> {
139    let converted = PgParams::convert_sql_params(params, ConversionMode::Query)?;
140    let rows = client
141        .query(query, converted.as_refs())
142        .await
143        .map_err(|e| SqlMiddlewareDbError::ExecutionError(format!("postgres select error: {e}")))?;
144    build_result_set_from_rows(&rows)
145}
146
147/// Execute a DML query on a client without managing transactions
148///
149/// # Errors
150/// Returns errors from parameter conversion or query execution.
151pub async fn execute_dml_on_client(
152    client: &Client,
153    query: &str,
154    params: &[RowValues],
155    err_label: &str,
156) -> Result<usize, SqlMiddlewareDbError> {
157    let converted = PgParams::convert_sql_params(params, ConversionMode::Execute)?;
158    let rows = client
159        .execute(query, converted.as_refs())
160        .await
161        .map_err(|e| SqlMiddlewareDbError::ExecutionError(format!("{err_label}: {e}")))?;
162    usize::try_from(rows).map_err(|e| {
163        SqlMiddlewareDbError::ExecutionError(format!(
164            "postgres affected rows conversion error: {e}"
165        ))
166    })
167}