1use std::error::Error;
3
4use crate::middleware::{
5 ConfigAndPool, ConversionMode, CustomDbRow, DatabaseType, MiddlewarePool, ParamConverter,
6 ResultSet, RowValues, SqlMiddlewareDbError,
7};
8use chrono::NaiveDateTime;
9use deadpool_postgres::Transaction;
10use deadpool_postgres::{Config as PgConfig, Object};
11use serde_json::Value;
12use tokio_postgres::{
13 types::{to_sql_checked, IsNull, ToSql, Type},
14 NoTls, Statement,
15};
16use tokio_util::bytes;
17
18impl ConfigAndPool {
22 pub async fn new_postgres(pg_config: PgConfig) -> Result<Self, SqlMiddlewareDbError> {
24 if pg_config.dbname.is_none() {
26 return Err(SqlMiddlewareDbError::ConfigError("dbname is required".to_string()));
27 }
28
29 if pg_config.host.is_none() {
30 return Err(SqlMiddlewareDbError::ConfigError("host is required".to_string()));
31 }
32 if pg_config.port.is_none() {
33 return Err(SqlMiddlewareDbError::ConfigError("port is required".to_string()));
34 }
35 if pg_config.user.is_none() {
36 return Err(SqlMiddlewareDbError::ConfigError("user is required".to_string()));
37 }
38 if pg_config.password.is_none() {
39 return Err(SqlMiddlewareDbError::ConfigError("password is required".to_string()));
40 }
41
42 let pg_pool = pg_config
44 .create_pool(Some(deadpool_postgres::Runtime::Tokio1), NoTls)
45 .map_err(|e| SqlMiddlewareDbError::ConnectionError(format!("Failed to create Postgres pool: {}", e)))?;
46
47 Ok(ConfigAndPool {
48 pool: MiddlewarePool::Postgres(pg_pool),
49 db_type: DatabaseType::Postgres,
50 })
51 }
52}
53
54pub struct Params<'a> {
56 references: Vec<&'a (dyn ToSql + Sync)>,
57}
58
59impl<'a> Params<'a> {
60 pub fn convert(params: &'a [RowValues]) -> Result<Params<'a>, SqlMiddlewareDbError> {
62 let references: Vec<&(dyn ToSql + Sync)> =
63 params.iter().map(|p| p as &(dyn ToSql + Sync)).collect();
64
65 Ok(Params { references })
66 }
67
68 pub fn convert_for_batch(
70 params: &'a Vec<RowValues>,
71 ) -> Result<Vec<&'a (dyn ToSql + Sync + 'a)>, SqlMiddlewareDbError> {
72 let mut references = Vec::new();
73 for p in params {
74 references.push(p as &(dyn ToSql + Sync));
75 }
76
77 Ok(references)
78 }
79
80 pub fn as_refs(&self) -> &[&(dyn ToSql + Sync)] {
82 &self.references
83 }
84}
85
86impl<'a> ParamConverter<'a> for Params<'a> {
87 type Converted = Params<'a>;
88
89 fn convert_sql_params(
90 params: &'a [RowValues],
91 _mode: ConversionMode,
92 ) -> Result<Self::Converted, SqlMiddlewareDbError> {
93 Self::convert(params)
95 }
96
97 fn supports_mode(_mode: ConversionMode) -> bool {
99 true
100 }
101}
102
103impl ToSql for RowValues {
104 fn to_sql(
105 &self,
106 ty: &Type,
107 out: &mut bytes::BytesMut,
108 ) -> Result<IsNull, Box<dyn Error + Sync + Send>> {
109 match self {
110 RowValues::Int(i) => (*i).to_sql(ty, out),
111 RowValues::Float(f) => (*f).to_sql(ty, out),
112 RowValues::Text(s) => s.to_sql(ty, out),
113 RowValues::Bool(b) => (*b).to_sql(ty, out),
114 RowValues::Timestamp(dt) => dt.to_sql(ty, out),
115 RowValues::Null => Ok(IsNull::Yes),
116 RowValues::JSON(jsval) => jsval.to_sql(ty, out),
117 RowValues::Blob(bytes) => bytes.to_sql(ty, out),
118 }
119 }
120
121 fn accepts(ty: &Type) -> bool {
122 match *ty {
124 Type::INT2 | Type::INT4 | Type::INT8 => true,
126 Type::FLOAT4 | Type::FLOAT8 => true,
128 Type::TEXT | Type::VARCHAR | Type::CHAR | Type::NAME => true,
130 Type::BOOL => true,
132 Type::TIMESTAMP | Type::TIMESTAMPTZ | Type::DATE => true,
134 Type::JSON | Type::JSONB => true,
136 Type::BYTEA => true,
138 _ => false,
140 }
141 }
142
143 to_sql_checked!();
144}
145
146pub async fn build_result_set<'a>(
148 stmt: &Statement,
149 params: &[&(dyn ToSql + Sync)],
150 transaction: &Transaction<'a>,
151) -> Result<ResultSet, SqlMiddlewareDbError> {
152 let rows = transaction
154 .query(stmt, params)
155 .await?;
156
157 let column_names: Vec<String> = stmt
158 .columns()
159 .iter()
160 .map(|col| col.name().to_string())
161 .collect();
162
163 let capacity = rows.len();
165 let mut result_set = ResultSet::with_capacity(capacity);
166 let column_names_rc = std::sync::Arc::new(column_names);
168
169 for row in rows {
170 let mut row_values = Vec::new();
171
172 for i in 0..column_names_rc.len() {
173 let value = postgres_extract_value(&row, i)?;
174 row_values.push(value);
175 }
176
177 result_set.add_row(CustomDbRow::new(column_names_rc.clone(), row_values));
178
179 result_set.rows_affected += 1;
180 }
181
182 Ok(result_set)
183}
184
185fn postgres_extract_value(
187 row: &tokio_postgres::Row,
188 idx: usize,
189) -> Result<RowValues, SqlMiddlewareDbError> {
190 let type_info = row.columns()[idx].type_();
192
193 if type_info.name() == "int4" || type_info.name() == "int8" {
196 let val: Option<i64> = row
197 .try_get(idx)?;
198 Ok(val.map_or(RowValues::Null, RowValues::Int))
199 } else if type_info.name() == "float4" || type_info.name() == "float8" {
200 let val: Option<f64> = row
201 .try_get(idx)?;
202 Ok(val.map_or(RowValues::Null, RowValues::Float))
203 } else if type_info.name() == "bool" {
204 let val: Option<bool> = row
205 .try_get(idx)?;
206 Ok(val.map_or(RowValues::Null, RowValues::Bool))
207 } else if type_info.name() == "timestamp" || type_info.name() == "timestamptz" {
208 let val: Option<NaiveDateTime> = row
209 .try_get(idx)?;
210 Ok(val.map_or(RowValues::Null, RowValues::Timestamp))
211 } else if type_info.name() == "json" || type_info.name() == "jsonb" {
212 let val: Option<Value> = row
213 .try_get(idx)?;
214 Ok(val.map_or(RowValues::Null, RowValues::JSON))
215 } else if type_info.name() == "bytea" {
216 let val: Option<Vec<u8>> = row
217 .try_get(idx)?;
218 Ok(val.map_or(RowValues::Null, RowValues::Blob))
219 } else if type_info.name() == "text"
220 || type_info.name() == "varchar"
221 || type_info.name() == "char"
222 {
223 let val: Option<String> = row
224 .try_get(idx)?;
225 Ok(val.map_or(RowValues::Null, RowValues::Text))
226 } else {
227 let val: Option<String> = row
229 .try_get(idx)?;
230 Ok(val.map_or(RowValues::Null, RowValues::Text))
231 }
232}
233
234pub async fn execute_batch(
236 pg_client: &mut Object,
237 query: &str,
238) -> Result<(), SqlMiddlewareDbError> {
239 let tx = pg_client.transaction().await?;
241
242 tx.batch_execute(query).await?;
244
245 tx.commit().await?;
247
248 Ok(())
249}
250
251pub async fn execute_select(
253 pg_client: &mut Object,
254 query: &str,
255 params: &[RowValues],
256) -> Result<ResultSet, SqlMiddlewareDbError> {
257 let params = Params::convert(params)?;
258 let tx = pg_client.transaction().await?;
259 let stmt = tx.prepare(query).await?;
260 let result_set = build_result_set(&stmt, params.as_refs(), &tx).await?;
261 tx.commit().await?;
262 Ok(result_set)
263}
264
265pub async fn execute_dml(
267 pg_client: &mut Object,
268 query: &str,
269 params: &[RowValues],
270) -> Result<usize, SqlMiddlewareDbError> {
271 let params = Params::convert(params)?;
272 let tx = pg_client.transaction().await?;
273
274 let stmt = tx.prepare(query).await?;
275 let rows = tx.execute(&stmt, params.as_refs()).await?;
276 tx.commit().await?;
277
278 Ok(rows as usize)
279}