sql_middleware/postgres/
query.rs1use crate::middleware::{ResultSet, RowValues, SqlMiddlewareDbError};
2use crate::types::{ConversionMode, ParamConverter};
3use chrono::NaiveDateTime;
4use serde_json::Value;
5use tokio_postgres::{Client, Statement, Transaction, types::ToSql};
6
7use super::params::Params as PgParams;
8
9pub async fn build_result_set(
14 stmt: &Statement,
15 params: &[&(dyn ToSql + Sync)],
16 transaction: &Transaction<'_>,
17) -> Result<ResultSet, SqlMiddlewareDbError> {
18 let rows = transaction.query(stmt, params).await?;
20
21 let column_names: Vec<String> = stmt
22 .columns()
23 .iter()
24 .map(|col| col.name().to_string())
25 .collect();
26
27 let capacity = rows.len();
29 let mut result_set = ResultSet::with_capacity(capacity);
30 let column_names_rc = std::sync::Arc::new(column_names);
32 result_set.set_column_names(column_names_rc);
33
34 for row in rows {
35 let mut row_values = Vec::new();
36
37 let col_count = result_set
38 .get_column_names()
39 .ok_or_else(|| {
40 SqlMiddlewareDbError::ExecutionError("No column names available".to_string())
41 })?
42 .len();
43
44 for i in 0..col_count {
45 let value = postgres_extract_value(&row, i)?;
46 row_values.push(value);
47 }
48
49 result_set.add_row_values(row_values);
50 }
51
52 Ok(result_set)
53}
54
55pub fn postgres_extract_value(
60 row: &tokio_postgres::Row,
61 idx: usize,
62) -> Result<RowValues, SqlMiddlewareDbError> {
63 let type_info = row.columns()[idx].type_();
65
66 if type_info.name() == "int2" {
69 let val: Option<i16> = row.try_get(idx)?;
70 Ok(val.map_or(RowValues::Null, |v| RowValues::Int(i64::from(v))))
71 } else if type_info.name() == "int4" {
72 let val: Option<i32> = row.try_get(idx)?;
73 Ok(val.map_or(RowValues::Null, |v| RowValues::Int(i64::from(v))))
74 } else if type_info.name() == "int8" {
75 let val: Option<i64> = row.try_get(idx)?;
76 Ok(val.map_or(RowValues::Null, RowValues::Int))
77 } else if type_info.name() == "float4" || type_info.name() == "float8" {
78 let val: Option<f64> = row.try_get(idx)?;
79 Ok(val.map_or(RowValues::Null, RowValues::Float))
80 } else if type_info.name() == "bool" {
81 let val: Option<bool> = row.try_get(idx)?;
82 Ok(val.map_or(RowValues::Null, RowValues::Bool))
83 } else if type_info.name() == "timestamp" || type_info.name() == "timestamptz" {
84 let val: Option<NaiveDateTime> = row.try_get(idx)?;
85 Ok(val.map_or(RowValues::Null, RowValues::Timestamp))
86 } else if type_info.name() == "json" || type_info.name() == "jsonb" {
87 let val: Option<Value> = row.try_get(idx)?;
88 Ok(val.map_or(RowValues::Null, RowValues::JSON))
89 } else if type_info.name() == "bytea" {
90 let val: Option<Vec<u8>> = row.try_get(idx)?;
91 Ok(val.map_or(RowValues::Null, RowValues::Blob))
92 } else if type_info.name() == "text"
93 || type_info.name() == "varchar"
94 || type_info.name() == "char"
95 {
96 let val: Option<String> = row.try_get(idx)?;
97 Ok(val.map_or(RowValues::Null, RowValues::Text))
98 } else {
99 let val: Option<String> = row.try_get(idx)?;
101 Ok(val.map_or(RowValues::Null, RowValues::Text))
102 }
103}
104
105pub fn build_result_set_from_rows(
110 rows: &[tokio_postgres::Row],
111) -> Result<ResultSet, SqlMiddlewareDbError> {
112 let mut result_set = ResultSet::with_capacity(rows.len());
113 if let Some(row) = rows.first() {
114 let cols: Vec<String> = row.columns().iter().map(|c| c.name().to_string()).collect();
115 result_set.set_column_names(std::sync::Arc::new(cols));
116 }
117
118 for row in rows {
119 let col_count = row.columns().len();
120 let mut row_values = Vec::with_capacity(col_count);
121 for idx in 0..col_count {
122 row_values.push(postgres_extract_value(row, idx)?);
123 }
124 result_set.add_row_values(row_values);
125 }
126
127 Ok(result_set)
128}
129
130pub async fn execute_query_on_client(
135 client: &Client,
136 query: &str,
137 params: &[RowValues],
138) -> Result<ResultSet, SqlMiddlewareDbError> {
139 let converted = PgParams::convert_sql_params(params, ConversionMode::Query)?;
140 let rows = client
141 .query(query, converted.as_refs())
142 .await
143 .map_err(|e| SqlMiddlewareDbError::ExecutionError(format!("postgres select error: {e}")))?;
144 build_result_set_from_rows(&rows)
145}
146
147pub async fn execute_dml_on_client(
152 client: &Client,
153 query: &str,
154 params: &[RowValues],
155 err_label: &str,
156) -> Result<usize, SqlMiddlewareDbError> {
157 let converted = PgParams::convert_sql_params(params, ConversionMode::Execute)?;
158 let rows = client
159 .execute(query, converted.as_refs())
160 .await
161 .map_err(|e| SqlMiddlewareDbError::ExecutionError(format!("{err_label}: {e}")))?;
162 usize::try_from(rows).map_err(|e| {
163 SqlMiddlewareDbError::ExecutionError(format!(
164 "postgres affected rows conversion error: {e}"
165 ))
166 })
167}