sql_middleware/
middleware.rs

// `thiserror::Error` (imported below) derives `Display` and `std::error::Error` for `SqlMiddlewareDbError`.
2use async_trait::async_trait;
3use chrono::NaiveDateTime;
4use clap::ValueEnum;
5use deadpool_postgres::{Object as PostgresObject, Pool as DeadpoolPostgresPool};
6use deadpool_sqlite::{rusqlite, Object as SqliteObject, Pool as DeadpoolSqlitePool};
7use rusqlite::Connection as SqliteConnectionType;
8use serde_json::Value as JsonValue;
9use thiserror::Error;
10
11use crate::{postgres, sqlite};
12pub type SqliteWritePool = DeadpoolSqlitePool;
13
14/// Wrapper around a database connection for generic code
15/// 
16/// This enum allows code to handle either PostgreSQL or SQLite 
17/// connections in a generic way.
18pub enum AnyConnWrapper<'a> {
19    /// PostgreSQL client connection
20    Postgres(&'a mut tokio_postgres::Client),
21    /// SQLite database connection
22    Sqlite(&'a mut SqliteConnectionType),
23}
24
25/// A query and its parameters bundled together
26/// 
27/// This type makes it easier to pass around a SQL query and its
28/// parameters as a single unit.
29#[derive(Debug, Clone)]
30pub struct QueryAndParams {
31    /// The SQL query string
32    pub query: String,
33    /// The parameters to be bound to the query
34    pub params: Vec<RowValues>,
35}
36
37impl QueryAndParams {
38    /// Create a new QueryAndParams with the given query string and parameters
39    ///
40    /// # Arguments
41    ///
42    /// * `query` - The SQL query string
43    /// * `params` - The parameters to bind to the query
44    ///
45    /// # Returns
46    ///
47    /// A new QueryAndParams instance
48    pub fn new(query: impl Into<String>, params: Vec<RowValues>) -> Self {
49        Self {
50            query: query.into(),
51            params,
52        }
53    }
54    
55    /// Create a new QueryAndParams with no parameters
56    ///
57    /// # Arguments
58    ///
59    /// * `query` - The SQL query string
60    ///
61    /// # Returns
62    ///
63    /// A new QueryAndParams instance with an empty parameter list
64    pub fn new_without_params(query: impl Into<String>) -> Self {
65        Self {
66            query: query.into(),
67            params: Vec::new(),
68        }
69    }
70}
71
72/// Values that can be stored in a database row or used as query parameters
73///
74/// This enum provides a unified representation of database values across
75/// different database engines.
76#[derive(Debug, Clone, PartialEq)]
77pub enum RowValues {
78    /// Integer value (64-bit)
79    Int(i64),
80    /// Floating point value (64-bit)
81    Float(f64),
82    /// Text/string value
83    Text(String),
84    /// Boolean value
85    Bool(bool),
86    /// Timestamp value
87    Timestamp(NaiveDateTime),
88    /// NULL value
89    Null,
90    /// JSON value
91    JSON(JsonValue),
92    /// Binary data
93    Blob(Vec<u8>),
94}
95
96impl RowValues {
97    /// Check if this value is NULL
98    pub fn is_null(&self) -> bool {
99        matches!(self, Self::Null)
100    }
101}
102
/// The database type supported by this middleware
// NOTE: derives clap's `ValueEnum`, so the variant names are accepted as CLI
// values, and clap's derive typically uses the variant doc comments below as
// help text (confirm against the clap version in use). Renaming variants or
// editing those doc comments changes CLI behavior/output.
#[derive(Debug, Clone, PartialEq, Eq, Hash, ValueEnum)]
pub enum DatabaseType {
    /// PostgreSQL database
    Postgres,
    /// SQLite database
    Sqlite,
}
111
112/// Connection pool for database access
113///
114/// This enum wraps the different connection pool types for the
115/// supported database engines.
116#[derive(Debug, Clone)]
117pub enum MiddlewarePool {
118    /// PostgreSQL connection pool
119    Postgres(DeadpoolPostgresPool),
120    /// SQLite connection pool
121    Sqlite(DeadpoolSqlitePool),
122}
123
124/// Configuration and connection pool for a database
125///
126/// This struct holds both the configuration and the connection pool
127/// for a database, making it easier to manage database connections.
128#[derive(Clone, Debug)]
129pub struct ConfigAndPool {
130    /// The connection pool
131    pub pool: MiddlewarePool,
132    /// The database type
133    pub db_type: DatabaseType,
134}
135
136#[derive(Debug, Error)]
137pub enum SqlMiddlewareDbError {
138    #[error(transparent)]
139    PostgresError(#[from] tokio_postgres::Error),
140    
141    #[error(transparent)]
142    SqliteError(#[from] rusqlite::Error),
143    
144    #[error(transparent)]
145    PoolErrorPostgres(#[from] deadpool::managed::PoolError<tokio_postgres::Error>),
146    
147    #[error(transparent)]
148    PoolErrorSqlite(#[from] deadpool::managed::PoolError<rusqlite::Error>),
149    
150    #[error("Configuration error: {0}")]
151    ConfigError(String),
152    
153    #[error("Connection error: {0}")]
154    ConnectionError(String),
155    
156    #[error("Parameter conversion error: {0}")]
157    ParameterError(String),
158    
159    #[error("SQL execution error: {0}")]
160    ExecutionError(String),
161    
162    #[error("Unimplemented feature: {0}")]
163    Unimplemented(String),
164    
165    #[error("Other database error: {0}")]
166    Other(String),
167}
168
169/// A row from a database query result
170///
171/// This struct represents a single row from a database query result,
172/// with access to both the column names and the values.
173#[derive(Debug, Clone)]
174pub struct CustomDbRow {
175    /// The column names for this row (shared across all rows in a result set)
176    pub column_names: std::sync::Arc<Vec<String>>,
177    /// The values for this row
178    pub rows: Vec<RowValues>,
179    // Internal cache for faster column lookups (to avoid repeated string comparisons)
180    #[doc(hidden)]
181    column_index_cache: std::sync::Arc<std::collections::HashMap<String, usize>>,
182}
183
184impl CustomDbRow {
185    /// Create a new database row
186    ///
187    /// # Arguments
188    ///
189    /// * `column_names` - The column names
190    /// * `rows` - The values for this row
191    ///
192    /// # Returns
193    ///
194    /// A new CustomDbRow instance
195    pub fn new(column_names: std::sync::Arc<Vec<String>>, rows: Vec<RowValues>) -> Self {
196        // Build a cache of column name to index for faster lookups
197        let cache = std::sync::Arc::new(
198            column_names
199                .iter()
200                .enumerate()
201                .map(|(i, name)| (name.clone(), i))
202                .collect::<std::collections::HashMap<_, _>>()
203        );
204        
205        Self { 
206            column_names, 
207            rows,
208            column_index_cache: cache,
209        }
210    }
211    
212    /// Get the index of a column by name
213    ///
214    /// # Arguments
215    ///
216    /// * `column_name` - The name of the column
217    ///
218    /// # Returns
219    ///
220    /// The index of the column, or None if not found
221    pub fn get_column_index(&self, column_name: &str) -> Option<usize> {
222        // First check the cache
223        if let Some(&idx) = self.column_index_cache.get(column_name) {
224            return Some(idx);
225        }
226        
227        // Fall back to linear search
228        self.column_names.iter().position(|col| col == column_name)
229    }
230
231    /// Get a value from the row by column name
232    ///
233    /// # Arguments
234    ///
235    /// * `column_name` - The name of the column
236    ///
237    /// # Returns
238    ///
239    /// The value at the column, or None if the column wasn't found
240    pub fn get(&self, column_name: &str) -> Option<&RowValues> {
241        let index_opt = self.get_column_index(column_name);
242        if let Some(idx) = index_opt {
243            self.rows.get(idx)
244        } else {
245            None
246        }
247    }
248    
249    /// Get a value from the row by column index
250    ///
251    /// # Arguments
252    ///
253    /// * `index` - The index of the column
254    ///
255    /// # Returns
256    ///
257    /// The value at the index, or None if the index is out of bounds
258    pub fn get_by_index(&self, index: usize) -> Option<&RowValues> {
259        self.rows.get(index)
260    }
261}
262
263/// A result set from a database query
264///
265/// This struct represents the result of a database query,
266/// containing the rows returned by the query and metadata.
267#[derive(Debug, Clone, Default)]
268pub struct ResultSet {
269    /// The rows returned by the query
270    pub results: Vec<CustomDbRow>,
271    /// The number of rows affected (for DML statements)
272    pub rows_affected: usize,
273}
274
275impl ResultSet {
276    /// Create a new result set with a known capacity
277    ///
278    /// # Arguments
279    ///
280    /// * `capacity` - The initial capacity for the result rows
281    ///
282    /// # Returns
283    ///
284    /// A new ResultSet instance with preallocated capacity
285    pub fn with_capacity(capacity: usize) -> ResultSet {
286        ResultSet {
287            results: Vec::with_capacity(capacity),
288            rows_affected: 0,
289        }
290    }
291    
292    /// Add a row to the result set
293    ///
294    /// # Arguments
295    ///
296    /// * `row` - The row to add
297    pub fn add_row(&mut self, row: CustomDbRow) {
298        self.results.push(row);
299        self.rows_affected += 1;
300    }
301}
302
303impl MiddlewarePool {
304    // Return a reference to self instead of cloning the entire pool
305    pub async fn get(&self) -> Result<&MiddlewarePool, SqlMiddlewareDbError> {
306        Ok(self)
307    }
308    pub async fn get_connection(
309        pool: &MiddlewarePool,
310    ) -> Result<MiddlewarePoolConnection, SqlMiddlewareDbError> {
311        match pool {
312            MiddlewarePool::Postgres(pool) => {
313                let conn: PostgresObject = pool
314                    .get()
315                    .await
316                    .map_err(SqlMiddlewareDbError::PoolErrorPostgres)?;
317                Ok(MiddlewarePoolConnection::Postgres(conn))
318            }
319            MiddlewarePool::Sqlite(pool) => {
320                let conn: SqliteObject = pool
321                    .get()
322                    .await
323                    .map_err(SqlMiddlewareDbError::PoolErrorSqlite)?;
324                Ok(MiddlewarePoolConnection::Sqlite(conn))
325            }
326        }
327    }
328}
329
330#[derive(Debug)]
331pub enum MiddlewarePoolConnection {
332    Postgres(PostgresObject),
333    Sqlite(SqliteObject),
334}
335
impl MiddlewarePoolConnection {
    /// Run an async closure against the pooled Postgres client.
    ///
    /// The closure receives [`AnyConnWrapper::Postgres`] wrapping a mutable
    /// borrow of the client, and the future it returns is awaited here.
    ///
    /// # Returns
    ///
    /// `Ok(output)`, where `output` is the future's own
    /// `Result<(), SqlMiddlewareDbError>`, so callers receive a nested `Result`.
    ///
    /// # Errors
    ///
    /// [`SqlMiddlewareDbError::Unimplemented`] when called on a SQLite
    /// connection; use `interact_sync` for SQLite instead.
    ///
    /// NOTE(review): `Fut` is a single type that cannot depend on the wrapper's
    /// lifetime and must also be `'static`. The returned future therefore
    /// cannot keep using the borrowed client, and any work with the connection
    /// has to happen in the closure body before the future is built. Confirm
    /// this is intended.
    pub async fn interact_async<F, Fut>(
        &mut self,
        func: F,
    ) -> Result<Fut::Output, SqlMiddlewareDbError>
    where
        F: FnOnce(AnyConnWrapper<'_>) -> Fut + Send + 'static,
        Fut: std::future::Future<Output = Result<(), SqlMiddlewareDbError>> + Send + 'static,
    {
        match self {
            MiddlewarePoolConnection::Postgres(pg_obj) => {
                // `as_mut` borrows the pooled object's managed value. The annotated
                // type presumably relies on deref coercion to reach the
                // `tokio_postgres::Client`; TODO confirm against the deadpool version.
                let client: &mut tokio_postgres::Client = pg_obj.as_mut();
                Ok(func(AnyConnWrapper::Postgres(client)).await)
            }
            MiddlewarePoolConnection::Sqlite(_) => {
                Err(SqlMiddlewareDbError::Unimplemented(
                    "interact_async is not supported for SQLite; use interact_sync instead".to_string()
                ))
            }
        }
    }

    /// Run a synchronous closure against the pooled SQLite connection.
    ///
    /// The closure is handed to `deadpool_sqlite`'s `interact` and receives
    /// [`AnyConnWrapper::Sqlite`] wrapping the `rusqlite` connection. Its return
    /// value is passed back as `Ok(R)`.
    ///
    /// # Errors
    ///
    /// - [`SqlMiddlewareDbError::Unimplemented`] when called on a Postgres
    ///   connection; use `interact_async` for Postgres instead.
    /// - Failures of the `interact` call itself, converted by `?`.
    ///   NOTE(review): that conversion needs a `From` impl into
    ///   `SqlMiddlewareDbError` that is not among the enum's `#[from]` variants.
    ///   Confirm where it is defined and which variant it maps to.
    pub async fn interact_sync<F, R>(&self, f: F) -> Result<R, SqlMiddlewareDbError>
    where
        F: FnOnce(AnyConnWrapper) -> R + Send + 'static,
        R: Send + 'static,
    {
        match self {
            MiddlewarePoolConnection::Sqlite(sqlite_obj) => {
                // Use `deadpool_sqlite`'s `interact` method
                sqlite_obj
                    .interact(move |conn| {
                        let wrapper = AnyConnWrapper::Sqlite(conn);
                        // The outer `?` strips the `interact` result; this inner
                        // `Ok` becomes the function's return value.
                        Ok(f(wrapper))
                    })
                    .await?
            }
            MiddlewarePoolConnection::Postgres(_) => {
                Err(SqlMiddlewareDbError::Unimplemented(
                    "interact_sync is not supported for Postgres; use interact_async instead".to_string()
                ))
            }
        }
    }
}
382
383// ----------------------------------------
384// Common impl blocks for DbError
385// ----------------------------------------
386// We don't need this anymore as thiserror already generates a Display implementation
387// The #[error] attributes on the enum variants define the format for each variant
388// This is commented out to show what was here originally
389/*
390impl fmt::Display for SqlMiddlewareDbError {
391    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
392        match self {
393            SqlMiddlewareDbError::PostgresError(e) => write!(f, "PostgresError: {}", e),
394            SqlMiddlewareDbError::SqliteError(e) => write!(f, "SqliteError: {}", e),
395            SqlMiddlewareDbError::ConfigError(msg) => write!(f, "ConfigError: {}", msg),
396            SqlMiddlewareDbError::ConnectionError(msg) => write!(f, "ConnectionError: {}", msg),
397            SqlMiddlewareDbError::ParameterError(msg) => write!(f, "ParameterError: {}", msg),
398            SqlMiddlewareDbError::ExecutionError(msg) => write!(f, "ExecutionError: {}", msg),
399            SqlMiddlewareDbError::Unimplemented(msg) => write!(f, "Unimplemented: {}", msg),
400            SqlMiddlewareDbError::Other(msg) => write!(f, "Other: {}", msg),
401            SqlMiddlewareDbError::PoolErrorSqlite(e) => write!(f, "PoolError: {:?}", e),
402            SqlMiddlewareDbError::PoolErrorPostgres(pool_error) => {
403                write!(f, "PoolErrorPostgres: {:?}", pool_error)
404            }
405        }
406    }
407}
408*/
409
410// ----------------------------------------
411// Impl for RowValues that is DB-agnostic
412// ----------------------------------------
413impl RowValues {
414    pub fn as_int(&self) -> Option<&i64> {
415        if let RowValues::Int(value) = self {
416            Some(value)
417        } else {
418            None
419        }
420    }
421
422    pub fn as_text(&self) -> Option<&str> {
423        if let RowValues::Text(value) = self {
424            Some(value)
425        } else {
426            None
427        }
428    }
429
430    pub fn as_bool(&self) -> Option<&bool> {
431        if let RowValues::Bool(value) = self {
432            return Some(value);
433        } else if let Some(i) = self.as_int() {
434            if *i == 1 {
435                return Some(&true);
436            } else if *i == 0 {
437                return Some(&false);
438            }
439        }
440        None
441    }
442
443    pub fn as_timestamp(&self) -> Option<chrono::NaiveDateTime> {
444        if let RowValues::Timestamp(value) = self {
445            return Some(value.clone());
446        } else if let Some(s) = self.as_text() {
447            // Try "YYYY-MM-DD HH:MM:SS"
448            if let Ok(dt) = chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%d %H:%M:%S") {
449                return Some(dt);
450            }
451            // Try "YYYY-MM-DD HH:MM:SS.SSS"
452            if let Ok(dt) = chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%d %H:%M:%S.%3f") {
453                return Some(dt);
454            }
455        }
456        None
457    }
458
459    pub fn as_float(&self) -> Option<f64> {
460        if let RowValues::Float(value) = self {
461            Some(*value)
462        } else {
463            None
464        }
465    }
466
467    pub fn as_blob(&self) -> Option<&[u8]> {
468        if let RowValues::Blob(bytes) = self {
469            Some(bytes)
470        } else {
471            None
472        }
473    }
474}
475
476#[async_trait]
477pub trait AsyncDatabaseExecutor {
478    /// Executes a batch of SQL queries (can be a mix of reads/writes) within a transaction. No parameters are supported.
479    async fn execute_batch(&mut self, query: &str) -> Result<(), SqlMiddlewareDbError>;
480
481    /// Executes a single SELECT statement and returns the result set.
482    async fn execute_select(
483        &mut self,
484        query: &str,
485        params: &[RowValues],
486    ) -> Result<ResultSet, SqlMiddlewareDbError>;
487
488    /// Executes a single DML statement (INSERT, UPDATE, DELETE, etc.) and returns the number of rows affected.
489    async fn execute_dml(
490        &mut self,
491        query: &str,
492        params: &[RowValues],
493    ) -> Result<usize, SqlMiddlewareDbError>;
494}
495
496#[async_trait]
497impl AsyncDatabaseExecutor for MiddlewarePoolConnection {
498    /// Executes a batch of SQL queries within a transaction by delegating to the specific database module.
499    async fn execute_batch(&mut self, query: &str) -> Result<(), SqlMiddlewareDbError> {
500        match self {
501            MiddlewarePoolConnection::Postgres(pg_client) => {
502                postgres::execute_batch(pg_client, query).await
503            }
504            MiddlewarePoolConnection::Sqlite(sqlite_client) => {
505                sqlite::execute_batch(sqlite_client, query).await
506            }
507        }
508    }
509    async fn execute_select(
510        &mut self,
511        query: &str,
512        params: &[RowValues],
513    ) -> Result<ResultSet, SqlMiddlewareDbError> {
514        match self {
515            MiddlewarePoolConnection::Postgres(pg_client) => {
516                postgres::execute_select(pg_client, query, params).await
517            }
518            MiddlewarePoolConnection::Sqlite(sqlite_client) => {
519                sqlite::execute_select(sqlite_client, query, params).await
520            }
521        }
522    }
523    async fn execute_dml(
524        &mut self,
525        query: &str,
526        params: &[RowValues],
527    ) -> Result<usize, SqlMiddlewareDbError> {
528        match self {
529            MiddlewarePoolConnection::Postgres(pg_client) => {
530                postgres::execute_dml(pg_client, query, &params).await
531            }
532            MiddlewarePoolConnection::Sqlite(sqlite_client) => {
533                sqlite::execute_dml(sqlite_client, query, params).await
534            }
535        }
536    }
537}
538
539/// Convert a slice of RowValues into database-specific parameters.
540/// This trait provides a unified interface for converting generic RowValues
541/// to database-specific parameter types.
542pub trait ParamConverter<'a> {
543    type Converted;
544
545    /// Convert a slice of RowValues into the backend’s parameter type.
546    fn convert_sql_params(
547        params: &'a [RowValues],
548        mode: ConversionMode,
549    ) -> Result<Self::Converted, SqlMiddlewareDbError>;
550    
551    /// Check if this converter supports the given mode
552    /// 
553    /// # Arguments
554    /// * `mode` - The conversion mode to check
555    /// 
556    /// # Returns
557    /// * `bool` - Whether this converter supports the mode
558    fn supports_mode(_mode: ConversionMode) -> bool {
559        true // By default, support both modes
560    }
561}
562
/// The conversion "mode": which kind of statement the converted parameters
/// will be used with.
///
/// Fieldless, so it also derives `Eq` and `Hash` and can be used as a map or
/// set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConversionMode {
    /// When the converted parameters will be used in a query (SELECT)
    Query,
    /// When the converted parameters will be used for statement execution (INSERT/UPDATE/etc.)
    Execute,
}