sql_middleware/
postgres.rs

1// postgres.rs
2use std::error::Error;
3
4use crate::middleware::{
5    ConfigAndPool, ConversionMode, CustomDbRow, DatabaseType, MiddlewarePool, ParamConverter,
6    ResultSet, RowValues, SqlMiddlewareDbError,
7};
8use chrono::NaiveDateTime;
9use deadpool_postgres::Transaction;
10use deadpool_postgres::{Config as PgConfig, Object};
11use serde_json::Value;
12use tokio_postgres::{
13    types::{to_sql_checked, IsNull, ToSql, Type},
14    NoTls, Statement,
15};
16use tokio_util::bytes;
17
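// No manual `From<tokio_postgres::Error> for SqlMiddlewareDbError` impl is needed here: the
// `#[from]` attribute on the `SqlMiddlewareDbError::PostgresError` variant generates it, which
// is what lets `?` convert Postgres errors throughout this file.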
20
21impl ConfigAndPool {
22    /// Asynchronous initializer for ConfigAndPool with Postgres
23    pub async fn new_postgres(pg_config: PgConfig) -> Result<Self, SqlMiddlewareDbError> {
24        // Validate all required config fields are present
25        if pg_config.dbname.is_none() {
26            return Err(SqlMiddlewareDbError::ConfigError("dbname is required".to_string()));
27        }
28
29        if pg_config.host.is_none() {
30            return Err(SqlMiddlewareDbError::ConfigError("host is required".to_string()));
31        }
32        if pg_config.port.is_none() {
33            return Err(SqlMiddlewareDbError::ConfigError("port is required".to_string()));
34        }
35        if pg_config.user.is_none() {
36            return Err(SqlMiddlewareDbError::ConfigError("user is required".to_string()));
37        }
38        if pg_config.password.is_none() {
39            return Err(SqlMiddlewareDbError::ConfigError("password is required".to_string()));
40        }
41
42        // Attempt to create connection pool
43        let pg_pool = pg_config
44            .create_pool(Some(deadpool_postgres::Runtime::Tokio1), NoTls)
45            .map_err(|e| SqlMiddlewareDbError::ConnectionError(format!("Failed to create Postgres pool: {}", e)))?;
46            
47        Ok(ConfigAndPool {
48            pool: MiddlewarePool::Postgres(pg_pool),
49            db_type: DatabaseType::Postgres,
50        })
51    }
52}
53
54/// Container for Postgres parameters with lifetime tracking
55pub struct Params<'a> {
56    references: Vec<&'a (dyn ToSql + Sync)>,
57}
58
59impl<'a> Params<'a> {
60    /// Convert from a slice of RowValues to Postgres parameters
61    pub fn convert(params: &'a [RowValues]) -> Result<Params<'a>, SqlMiddlewareDbError> {
62        let references: Vec<&(dyn ToSql + Sync)> =
63            params.iter().map(|p| p as &(dyn ToSql + Sync)).collect();
64
65        Ok(Params { references })
66    }
67
68    /// Convert a Vec of RowValues for batch operations
69    pub fn convert_for_batch(
70        params: &'a Vec<RowValues>,
71    ) -> Result<Vec<&'a (dyn ToSql + Sync + 'a)>, SqlMiddlewareDbError> {
72        let mut references = Vec::new();
73        for p in params {
74            references.push(p as &(dyn ToSql + Sync));
75        }
76
77        Ok(references)
78    }
79
80    /// Get a reference to the underlying parameter array
81    pub fn as_refs(&self) -> &[&(dyn ToSql + Sync)] {
82        &self.references
83    }
84}
85
86impl<'a> ParamConverter<'a> for Params<'a> {
87    type Converted = Params<'a>;
88
89    fn convert_sql_params(
90        params: &'a [RowValues],
91        _mode: ConversionMode,
92    ) -> Result<Self::Converted, SqlMiddlewareDbError> {
93        // Simply delegate to your existing conversion:
94        Self::convert(params)
95    }
96    
97    // PostgresParams supports both query and execution modes
98    fn supports_mode(_mode: ConversionMode) -> bool {
99        true
100    }
101}
102
103impl ToSql for RowValues {
104    fn to_sql(
105        &self,
106        ty: &Type,
107        out: &mut bytes::BytesMut,
108    ) -> Result<IsNull, Box<dyn Error + Sync + Send>> {
109        match self {
110            RowValues::Int(i) => (*i).to_sql(ty, out),
111            RowValues::Float(f) => (*f).to_sql(ty, out),
112            RowValues::Text(s) => s.to_sql(ty, out),
113            RowValues::Bool(b) => (*b).to_sql(ty, out),
114            RowValues::Timestamp(dt) => dt.to_sql(ty, out),
115            RowValues::Null => Ok(IsNull::Yes),
116            RowValues::JSON(jsval) => jsval.to_sql(ty, out),
117            RowValues::Blob(bytes) => bytes.to_sql(ty, out),
118        }
119    }
120
121    fn accepts(ty: &Type) -> bool {
122        // Only accept types we can properly handle
123        match *ty {
124            // Integer types
125            Type::INT2 | Type::INT4 | Type::INT8 => true,
126            // Floating point types
127            Type::FLOAT4 | Type::FLOAT8 => true,
128            // Text types
129            Type::TEXT | Type::VARCHAR | Type::CHAR | Type::NAME => true,
130            // Boolean type
131            Type::BOOL => true,
132            // Date/time types
133            Type::TIMESTAMP | Type::TIMESTAMPTZ | Type::DATE => true,
134            // JSON types
135            Type::JSON | Type::JSONB => true,
136            // Binary data
137            Type::BYTEA => true,
138            // For any other type, we don't accept
139            _ => false,
140        }
141    }
142
143    to_sql_checked!();
144}
145
146/// Build a result set from a Postgres query execution
147pub async fn build_result_set<'a>(
148    stmt: &Statement,
149    params: &[&(dyn ToSql + Sync)],
150    transaction: &Transaction<'a>,
151) -> Result<ResultSet, SqlMiddlewareDbError> {
152    // Execute the query
153    let rows = transaction
154        .query(stmt, params)
155        .await?;
156
157    let column_names: Vec<String> = stmt
158        .columns()
159        .iter()
160        .map(|col| col.name().to_string())
161        .collect();
162
163    // Preallocate capacity if we can estimate the number of rows
164    let capacity = rows.len();
165    let mut result_set = ResultSet::with_capacity(capacity);
166    // Store column names once in the result set
167    let column_names_rc = std::sync::Arc::new(column_names);
168
169    for row in rows {
170        let mut row_values = Vec::new();
171
172        for i in 0..column_names_rc.len() {
173            let value = postgres_extract_value(&row, i)?;
174            row_values.push(value);
175        }
176
177        result_set.add_row(CustomDbRow::new(column_names_rc.clone(), row_values));
178
179        result_set.rows_affected += 1;
180    }
181
182    Ok(result_set)
183}
184
185/// Extracts a RowValues from a tokio_postgres Row at the given index
186fn postgres_extract_value(
187    row: &tokio_postgres::Row,
188    idx: usize,
189) -> Result<RowValues, SqlMiddlewareDbError> {
190    // Determine the type of the column and extract accordingly
191    let type_info = row.columns()[idx].type_();
192
193    // Match on the type based on PostgreSQL type OIDs or names
194    // For simplicity, we'll handle common types. You may need to expand this.
195    if type_info.name() == "int4" || type_info.name() == "int8" {
196        let val: Option<i64> = row
197            .try_get(idx)?;
198        Ok(val.map_or(RowValues::Null, RowValues::Int))
199    } else if type_info.name() == "float4" || type_info.name() == "float8" {
200        let val: Option<f64> = row
201            .try_get(idx)?;
202        Ok(val.map_or(RowValues::Null, RowValues::Float))
203    } else if type_info.name() == "bool" {
204        let val: Option<bool> = row
205            .try_get(idx)?;
206        Ok(val.map_or(RowValues::Null, RowValues::Bool))
207    } else if type_info.name() == "timestamp" || type_info.name() == "timestamptz" {
208        let val: Option<NaiveDateTime> = row
209            .try_get(idx)?;
210        Ok(val.map_or(RowValues::Null, RowValues::Timestamp))
211    } else if type_info.name() == "json" || type_info.name() == "jsonb" {
212        let val: Option<Value> = row
213            .try_get(idx)?;
214        Ok(val.map_or(RowValues::Null, RowValues::JSON))
215    } else if type_info.name() == "bytea" {
216        let val: Option<Vec<u8>> = row
217            .try_get(idx)?;
218        Ok(val.map_or(RowValues::Null, RowValues::Blob))
219    } else if type_info.name() == "text"
220        || type_info.name() == "varchar"
221        || type_info.name() == "char"
222    {
223        let val: Option<String> = row
224            .try_get(idx)?;
225        Ok(val.map_or(RowValues::Null, RowValues::Text))
226    } else {
227        // For other types, attempt to get as string
228        let val: Option<String> = row
229            .try_get(idx)?;
230        Ok(val.map_or(RowValues::Null, RowValues::Text))
231    }
232}
233
234/// Execute a batch of SQL statements for Postgres
235pub async fn execute_batch(
236    pg_client: &mut Object,
237    query: &str,
238) -> Result<(), SqlMiddlewareDbError> {
239    // Begin a transaction
240    let tx = pg_client.transaction().await?;
241
242    // Execute the batch of queries
243    tx.batch_execute(query).await?;
244
245    // Commit the transaction
246    tx.commit().await?;
247
248    Ok(())
249}
250
251/// Execute a SELECT query with parameters
252pub async fn execute_select(
253    pg_client: &mut Object,
254    query: &str,
255    params: &[RowValues],
256) -> Result<ResultSet, SqlMiddlewareDbError> {
257    let params = Params::convert(params)?;
258    let tx = pg_client.transaction().await?;
259    let stmt = tx.prepare(query).await?;
260    let result_set = build_result_set(&stmt, params.as_refs(), &tx).await?;
261    tx.commit().await?;
262    Ok(result_set)
263}
264
265/// Execute a DML query (INSERT, UPDATE, DELETE) with parameters
266pub async fn execute_dml(
267    pg_client: &mut Object,
268    query: &str,
269    params: &[RowValues],
270) -> Result<usize, SqlMiddlewareDbError> {
271    let params = Params::convert(params)?;
272    let tx = pg_client.transaction().await?;
273
274    let stmt = tx.prepare(query).await?;
275    let rows = tx.execute(&stmt, params.as_refs()).await?;
276    tx.commit().await?;
277
278    Ok(rows as usize)
279}