sql_middleware/postgres/
query.rs

1use crate::middleware::{ResultSet, RowValues, SqlMiddlewareDbError};
2use chrono::NaiveDateTime;
3use deadpool_postgres::Transaction;
4use serde_json::Value;
5use tokio_postgres::{Statement, types::ToSql};
6
7/// Build a result set from a Postgres query execution
8///
9/// # Errors
10/// Returns errors from query execution or result processing.
11pub async fn build_result_set(
12    stmt: &Statement,
13    params: &[&(dyn ToSql + Sync)],
14    transaction: &Transaction<'_>,
15) -> Result<ResultSet, SqlMiddlewareDbError> {
16    // Execute the query
17    let rows = transaction.query(stmt, params).await?;
18
19    let column_names: Vec<String> = stmt
20        .columns()
21        .iter()
22        .map(|col| col.name().to_string())
23        .collect();
24
25    // Preallocate capacity if we can estimate the number of rows
26    let capacity = rows.len();
27    let mut result_set = ResultSet::with_capacity(capacity);
28    // Store column names once in the result set
29    let column_names_rc = std::sync::Arc::new(column_names);
30    result_set.set_column_names(column_names_rc);
31
32    for row in rows {
33        let mut row_values = Vec::new();
34
35        let col_count = result_set
36            .get_column_names()
37            .ok_or_else(|| {
38                SqlMiddlewareDbError::ExecutionError("No column names available".to_string())
39            })?
40            .len();
41
42        for i in 0..col_count {
43            let value = postgres_extract_value(&row, i)?;
44            row_values.push(value);
45        }
46
47        result_set.add_row_values(row_values);
48    }
49
50    Ok(result_set)
51}
52
53/// Extracts a `RowValues` from a `tokio_postgres` Row at the given index
54fn postgres_extract_value(
55    row: &tokio_postgres::Row,
56    idx: usize,
57) -> Result<RowValues, SqlMiddlewareDbError> {
58    // Determine the type of the column and extract accordingly
59    let type_info = row.columns()[idx].type_();
60
61    // Match on the type based on PostgreSQL type OIDs or names
62    // For simplicity, we'll handle common types. You may need to expand this.
63    if type_info.name() == "int4" || type_info.name() == "int8" {
64        let val: Option<i64> = row.try_get(idx)?;
65        Ok(val.map_or(RowValues::Null, RowValues::Int))
66    } else if type_info.name() == "float4" || type_info.name() == "float8" {
67        let val: Option<f64> = row.try_get(idx)?;
68        Ok(val.map_or(RowValues::Null, RowValues::Float))
69    } else if type_info.name() == "bool" {
70        let val: Option<bool> = row.try_get(idx)?;
71        Ok(val.map_or(RowValues::Null, RowValues::Bool))
72    } else if type_info.name() == "timestamp" || type_info.name() == "timestamptz" {
73        let val: Option<NaiveDateTime> = row.try_get(idx)?;
74        Ok(val.map_or(RowValues::Null, RowValues::Timestamp))
75    } else if type_info.name() == "json" || type_info.name() == "jsonb" {
76        let val: Option<Value> = row.try_get(idx)?;
77        Ok(val.map_or(RowValues::Null, RowValues::JSON))
78    } else if type_info.name() == "bytea" {
79        let val: Option<Vec<u8>> = row.try_get(idx)?;
80        Ok(val.map_or(RowValues::Null, RowValues::Blob))
81    } else if type_info.name() == "text"
82        || type_info.name() == "varchar"
83        || type_info.name() == "char"
84    {
85        let val: Option<String> = row.try_get(idx)?;
86        Ok(val.map_or(RowValues::Null, RowValues::Text))
87    } else {
88        // For other types, attempt to get as string
89        let val: Option<String> = row.try_get(idx)?;
90        Ok(val.map_or(RowValues::Null, RowValues::Text))
91    }
92}