sql_middleware/middleware.rs
// The `thiserror::Error` derive (imported below) generates the `Display` and `std::error::Error` impls for `SqlMiddlewareDbError`.
2use async_trait::async_trait;
3use chrono::NaiveDateTime;
4use clap::ValueEnum;
5use deadpool_postgres::{Object as PostgresObject, Pool as DeadpoolPostgresPool};
6use deadpool_sqlite::{rusqlite, Object as SqliteObject, Pool as DeadpoolSqlitePool};
7use rusqlite::Connection as SqliteConnectionType;
8use serde_json::Value as JsonValue;
9use thiserror::Error;
10
11use crate::{postgres, sqlite};
12pub type SqliteWritePool = DeadpoolSqlitePool;
13
14/// Wrapper around a database connection for generic code
15///
16/// This enum allows code to handle either PostgreSQL or SQLite
17/// connections in a generic way.
18pub enum AnyConnWrapper<'a> {
19 /// PostgreSQL client connection
20 Postgres(&'a mut tokio_postgres::Client),
21 /// SQLite database connection
22 Sqlite(&'a mut SqliteConnectionType),
23}
24
25/// A query and its parameters bundled together
26///
27/// This type makes it easier to pass around a SQL query and its
28/// parameters as a single unit.
29#[derive(Debug, Clone)]
30pub struct QueryAndParams {
31 /// The SQL query string
32 pub query: String,
33 /// The parameters to be bound to the query
34 pub params: Vec<RowValues>,
35}
36
37impl QueryAndParams {
38 /// Create a new QueryAndParams with the given query string and parameters
39 ///
40 /// # Arguments
41 ///
42 /// * `query` - The SQL query string
43 /// * `params` - The parameters to bind to the query
44 ///
45 /// # Returns
46 ///
47 /// A new QueryAndParams instance
48 pub fn new(query: impl Into<String>, params: Vec<RowValues>) -> Self {
49 Self {
50 query: query.into(),
51 params,
52 }
53 }
54
55 /// Create a new QueryAndParams with no parameters
56 ///
57 /// # Arguments
58 ///
59 /// * `query` - The SQL query string
60 ///
61 /// # Returns
62 ///
63 /// A new QueryAndParams instance with an empty parameter list
64 pub fn new_without_params(query: impl Into<String>) -> Self {
65 Self {
66 query: query.into(),
67 params: Vec::new(),
68 }
69 }
70}
71
72/// Values that can be stored in a database row or used as query parameters
73///
74/// This enum provides a unified representation of database values across
75/// different database engines.
76#[derive(Debug, Clone, PartialEq)]
77pub enum RowValues {
78 /// Integer value (64-bit)
79 Int(i64),
80 /// Floating point value (64-bit)
81 Float(f64),
82 /// Text/string value
83 Text(String),
84 /// Boolean value
85 Bool(bool),
86 /// Timestamp value
87 Timestamp(NaiveDateTime),
88 /// NULL value
89 Null,
90 /// JSON value
91 JSON(JsonValue),
92 /// Binary data
93 Blob(Vec<u8>),
94}
95
96impl RowValues {
97 /// Check if this value is NULL
98 pub fn is_null(&self) -> bool {
99 matches!(self, Self::Null)
100 }
101}
102
103/// The database type supported by this middleware
104#[derive(Debug, Clone, PartialEq, Eq, Hash, ValueEnum)]
105pub enum DatabaseType {
106 /// PostgreSQL database
107 Postgres,
108 /// SQLite database
109 Sqlite,
110}
111
112/// Connection pool for database access
113///
114/// This enum wraps the different connection pool types for the
115/// supported database engines.
116#[derive(Debug, Clone)]
117pub enum MiddlewarePool {
118 /// PostgreSQL connection pool
119 Postgres(DeadpoolPostgresPool),
120 /// SQLite connection pool
121 Sqlite(DeadpoolSqlitePool),
122}
123
124/// Configuration and connection pool for a database
125///
126/// This struct holds both the configuration and the connection pool
127/// for a database, making it easier to manage database connections.
128#[derive(Clone, Debug)]
129pub struct ConfigAndPool {
130 /// The connection pool
131 pub pool: MiddlewarePool,
132 /// The database type
133 pub db_type: DatabaseType,
134}
135
136#[derive(Debug, Error)]
137pub enum SqlMiddlewareDbError {
138 #[error(transparent)]
139 PostgresError(#[from] tokio_postgres::Error),
140
141 #[error(transparent)]
142 SqliteError(#[from] rusqlite::Error),
143
144 #[error(transparent)]
145 PoolErrorPostgres(#[from] deadpool::managed::PoolError<tokio_postgres::Error>),
146
147 #[error(transparent)]
148 PoolErrorSqlite(#[from] deadpool::managed::PoolError<rusqlite::Error>),
149
150 #[error("Configuration error: {0}")]
151 ConfigError(String),
152
153 #[error("Connection error: {0}")]
154 ConnectionError(String),
155
156 #[error("Parameter conversion error: {0}")]
157 ParameterError(String),
158
159 #[error("SQL execution error: {0}")]
160 ExecutionError(String),
161
162 #[error("Unimplemented feature: {0}")]
163 Unimplemented(String),
164
165 #[error("Other database error: {0}")]
166 Other(String),
167}
168
169/// A row from a database query result
170///
171/// This struct represents a single row from a database query result,
172/// with access to both the column names and the values.
173#[derive(Debug, Clone)]
174pub struct CustomDbRow {
175 /// The column names for this row (shared across all rows in a result set)
176 pub column_names: std::sync::Arc<Vec<String>>,
177 /// The values for this row
178 pub rows: Vec<RowValues>,
179 // Internal cache for faster column lookups (to avoid repeated string comparisons)
180 #[doc(hidden)]
181 column_index_cache: std::sync::Arc<std::collections::HashMap<String, usize>>,
182}
183
184impl CustomDbRow {
185 /// Create a new database row
186 ///
187 /// # Arguments
188 ///
189 /// * `column_names` - The column names
190 /// * `rows` - The values for this row
191 ///
192 /// # Returns
193 ///
194 /// A new CustomDbRow instance
195 pub fn new(column_names: std::sync::Arc<Vec<String>>, rows: Vec<RowValues>) -> Self {
196 // Build a cache of column name to index for faster lookups
197 let cache = std::sync::Arc::new(
198 column_names
199 .iter()
200 .enumerate()
201 .map(|(i, name)| (name.clone(), i))
202 .collect::<std::collections::HashMap<_, _>>()
203 );
204
205 Self {
206 column_names,
207 rows,
208 column_index_cache: cache,
209 }
210 }
211
212 /// Get the index of a column by name
213 ///
214 /// # Arguments
215 ///
216 /// * `column_name` - The name of the column
217 ///
218 /// # Returns
219 ///
220 /// The index of the column, or None if not found
221 pub fn get_column_index(&self, column_name: &str) -> Option<usize> {
222 // First check the cache
223 if let Some(&idx) = self.column_index_cache.get(column_name) {
224 return Some(idx);
225 }
226
227 // Fall back to linear search
228 self.column_names.iter().position(|col| col == column_name)
229 }
230
231 /// Get a value from the row by column name
232 ///
233 /// # Arguments
234 ///
235 /// * `column_name` - The name of the column
236 ///
237 /// # Returns
238 ///
239 /// The value at the column, or None if the column wasn't found
240 pub fn get(&self, column_name: &str) -> Option<&RowValues> {
241 let index_opt = self.get_column_index(column_name);
242 if let Some(idx) = index_opt {
243 self.rows.get(idx)
244 } else {
245 None
246 }
247 }
248
249 /// Get a value from the row by column index
250 ///
251 /// # Arguments
252 ///
253 /// * `index` - The index of the column
254 ///
255 /// # Returns
256 ///
257 /// The value at the index, or None if the index is out of bounds
258 pub fn get_by_index(&self, index: usize) -> Option<&RowValues> {
259 self.rows.get(index)
260 }
261}
262
263/// A result set from a database query
264///
265/// This struct represents the result of a database query,
266/// containing the rows returned by the query and metadata.
267#[derive(Debug, Clone, Default)]
268pub struct ResultSet {
269 /// The rows returned by the query
270 pub results: Vec<CustomDbRow>,
271 /// The number of rows affected (for DML statements)
272 pub rows_affected: usize,
273}
274
275impl ResultSet {
276 /// Create a new result set with a known capacity
277 ///
278 /// # Arguments
279 ///
280 /// * `capacity` - The initial capacity for the result rows
281 ///
282 /// # Returns
283 ///
284 /// A new ResultSet instance with preallocated capacity
285 pub fn with_capacity(capacity: usize) -> ResultSet {
286 ResultSet {
287 results: Vec::with_capacity(capacity),
288 rows_affected: 0,
289 }
290 }
291
292 /// Add a row to the result set
293 ///
294 /// # Arguments
295 ///
296 /// * `row` - The row to add
297 pub fn add_row(&mut self, row: CustomDbRow) {
298 self.results.push(row);
299 self.rows_affected += 1;
300 }
301}
302
303impl MiddlewarePool {
304 // Return a reference to self instead of cloning the entire pool
305 pub async fn get(&self) -> Result<&MiddlewarePool, SqlMiddlewareDbError> {
306 Ok(self)
307 }
308 pub async fn get_connection(
309 pool: &MiddlewarePool,
310 ) -> Result<MiddlewarePoolConnection, SqlMiddlewareDbError> {
311 match pool {
312 MiddlewarePool::Postgres(pool) => {
313 let conn: PostgresObject = pool
314 .get()
315 .await
316 .map_err(SqlMiddlewareDbError::PoolErrorPostgres)?;
317 Ok(MiddlewarePoolConnection::Postgres(conn))
318 }
319 MiddlewarePool::Sqlite(pool) => {
320 let conn: SqliteObject = pool
321 .get()
322 .await
323 .map_err(SqlMiddlewareDbError::PoolErrorSqlite)?;
324 Ok(MiddlewarePoolConnection::Sqlite(conn))
325 }
326 }
327 }
328}
329
330#[derive(Debug)]
331pub enum MiddlewarePoolConnection {
332 Postgres(PostgresObject),
333 Sqlite(SqliteObject),
334}
335
336impl MiddlewarePoolConnection {
337 pub async fn interact_async<F, Fut>(
338 &mut self,
339 func: F,
340 ) -> Result<Fut::Output, SqlMiddlewareDbError>
341 where
342 F: FnOnce(AnyConnWrapper<'_>) -> Fut + Send + 'static,
343 Fut: std::future::Future<Output = Result<(), SqlMiddlewareDbError>> + Send + 'static,
344 {
345 match self {
346 MiddlewarePoolConnection::Postgres(pg_obj) => {
347 // Assuming PostgresObject dereferences to tokio_postgres::Client
348 let client: &mut tokio_postgres::Client = pg_obj.as_mut();
349 Ok(func(AnyConnWrapper::Postgres(client)).await)
350 }
351 MiddlewarePoolConnection::Sqlite(_) => {
352 Err(SqlMiddlewareDbError::Unimplemented(
353 "interact_async is not supported for SQLite; use interact_sync instead".to_string()
354 ))
355 }
356 }
357 }
358
359 pub async fn interact_sync<F, R>(&self, f: F) -> Result<R, SqlMiddlewareDbError>
360 where
361 F: FnOnce(AnyConnWrapper) -> R + Send + 'static,
362 R: Send + 'static,
363 {
364 match self {
365 MiddlewarePoolConnection::Sqlite(sqlite_obj) => {
366 // Use `deadpool_sqlite`'s `interact` method
367 sqlite_obj
368 .interact(move |conn| {
369 let wrapper = AnyConnWrapper::Sqlite(conn);
370 Ok(f(wrapper))
371 })
372 .await?
373 }
374 MiddlewarePoolConnection::Postgres(_) => {
375 Err(SqlMiddlewareDbError::Unimplemented(
376 "interact_sync is not supported for Postgres; use interact_async instead".to_string()
377 ))
378 }
379 }
380 }
381}
382
383// ----------------------------------------
384// Common impl blocks for DbError
385// ----------------------------------------
386// We don't need this anymore as thiserror already generates a Display implementation
387// The #[error] attributes on the enum variants define the format for each variant
388// This is commented out to show what was here originally
389/*
390impl fmt::Display for SqlMiddlewareDbError {
391 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
392 match self {
393 SqlMiddlewareDbError::PostgresError(e) => write!(f, "PostgresError: {}", e),
394 SqlMiddlewareDbError::SqliteError(e) => write!(f, "SqliteError: {}", e),
395 SqlMiddlewareDbError::ConfigError(msg) => write!(f, "ConfigError: {}", msg),
396 SqlMiddlewareDbError::ConnectionError(msg) => write!(f, "ConnectionError: {}", msg),
397 SqlMiddlewareDbError::ParameterError(msg) => write!(f, "ParameterError: {}", msg),
398 SqlMiddlewareDbError::ExecutionError(msg) => write!(f, "ExecutionError: {}", msg),
399 SqlMiddlewareDbError::Unimplemented(msg) => write!(f, "Unimplemented: {}", msg),
400 SqlMiddlewareDbError::Other(msg) => write!(f, "Other: {}", msg),
401 SqlMiddlewareDbError::PoolErrorSqlite(e) => write!(f, "PoolError: {:?}", e),
402 SqlMiddlewareDbError::PoolErrorPostgres(pool_error) => {
403 write!(f, "PoolErrorPostgres: {:?}", pool_error)
404 }
405 }
406 }
407}
408*/
409
410// ----------------------------------------
411// Impl for RowValues that is DB-agnostic
412// ----------------------------------------
413impl RowValues {
414 pub fn as_int(&self) -> Option<&i64> {
415 if let RowValues::Int(value) = self {
416 Some(value)
417 } else {
418 None
419 }
420 }
421
422 pub fn as_text(&self) -> Option<&str> {
423 if let RowValues::Text(value) = self {
424 Some(value)
425 } else {
426 None
427 }
428 }
429
430 pub fn as_bool(&self) -> Option<&bool> {
431 if let RowValues::Bool(value) = self {
432 return Some(value);
433 } else if let Some(i) = self.as_int() {
434 if *i == 1 {
435 return Some(&true);
436 } else if *i == 0 {
437 return Some(&false);
438 }
439 }
440 None
441 }
442
443 pub fn as_timestamp(&self) -> Option<chrono::NaiveDateTime> {
444 if let RowValues::Timestamp(value) = self {
445 return Some(value.clone());
446 } else if let Some(s) = self.as_text() {
447 // Try "YYYY-MM-DD HH:MM:SS"
448 if let Ok(dt) = chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%d %H:%M:%S") {
449 return Some(dt);
450 }
451 // Try "YYYY-MM-DD HH:MM:SS.SSS"
452 if let Ok(dt) = chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%d %H:%M:%S.%3f") {
453 return Some(dt);
454 }
455 }
456 None
457 }
458
459 pub fn as_float(&self) -> Option<f64> {
460 if let RowValues::Float(value) = self {
461 Some(*value)
462 } else {
463 None
464 }
465 }
466
467 pub fn as_blob(&self) -> Option<&[u8]> {
468 if let RowValues::Blob(bytes) = self {
469 Some(bytes)
470 } else {
471 None
472 }
473 }
474}
475
476#[async_trait]
477pub trait AsyncDatabaseExecutor {
478 /// Executes a batch of SQL queries (can be a mix of reads/writes) within a transaction. No parameters are supported.
479 async fn execute_batch(&mut self, query: &str) -> Result<(), SqlMiddlewareDbError>;
480
481 /// Executes a single SELECT statement and returns the result set.
482 async fn execute_select(
483 &mut self,
484 query: &str,
485 params: &[RowValues],
486 ) -> Result<ResultSet, SqlMiddlewareDbError>;
487
488 /// Executes a single DML statement (INSERT, UPDATE, DELETE, etc.) and returns the number of rows affected.
489 async fn execute_dml(
490 &mut self,
491 query: &str,
492 params: &[RowValues],
493 ) -> Result<usize, SqlMiddlewareDbError>;
494}
495
496#[async_trait]
497impl AsyncDatabaseExecutor for MiddlewarePoolConnection {
498 /// Executes a batch of SQL queries within a transaction by delegating to the specific database module.
499 async fn execute_batch(&mut self, query: &str) -> Result<(), SqlMiddlewareDbError> {
500 match self {
501 MiddlewarePoolConnection::Postgres(pg_client) => {
502 postgres::execute_batch(pg_client, query).await
503 }
504 MiddlewarePoolConnection::Sqlite(sqlite_client) => {
505 sqlite::execute_batch(sqlite_client, query).await
506 }
507 }
508 }
509 async fn execute_select(
510 &mut self,
511 query: &str,
512 params: &[RowValues],
513 ) -> Result<ResultSet, SqlMiddlewareDbError> {
514 match self {
515 MiddlewarePoolConnection::Postgres(pg_client) => {
516 postgres::execute_select(pg_client, query, params).await
517 }
518 MiddlewarePoolConnection::Sqlite(sqlite_client) => {
519 sqlite::execute_select(sqlite_client, query, params).await
520 }
521 }
522 }
523 async fn execute_dml(
524 &mut self,
525 query: &str,
526 params: &[RowValues],
527 ) -> Result<usize, SqlMiddlewareDbError> {
528 match self {
529 MiddlewarePoolConnection::Postgres(pg_client) => {
530 postgres::execute_dml(pg_client, query, ¶ms).await
531 }
532 MiddlewarePoolConnection::Sqlite(sqlite_client) => {
533 sqlite::execute_dml(sqlite_client, query, params).await
534 }
535 }
536 }
537}
538
539/// Convert a slice of RowValues into database-specific parameters.
540/// This trait provides a unified interface for converting generic RowValues
541/// to database-specific parameter types.
542pub trait ParamConverter<'a> {
543 type Converted;
544
545 /// Convert a slice of RowValues into the backend’s parameter type.
546 fn convert_sql_params(
547 params: &'a [RowValues],
548 mode: ConversionMode,
549 ) -> Result<Self::Converted, SqlMiddlewareDbError>;
550
551 /// Check if this converter supports the given mode
552 ///
553 /// # Arguments
554 /// * `mode` - The conversion mode to check
555 ///
556 /// # Returns
557 /// * `bool` - Whether this converter supports the mode
558 fn supports_mode(_mode: ConversionMode) -> bool {
559 true // By default, support both modes
560 }
561}
562
/// How converted parameters will be used, so a [`ParamConverter`] can pick a
/// representation suited to queries or to statement execution.
///
/// Derives `Eq` and `Hash` (like `DatabaseType`) so it can be used as a
/// map or set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConversionMode {
    /// Parameters bound to a query that returns rows (SELECT).
    Query,
    /// Parameters bound to a statement executed for its effect
    /// (INSERT/UPDATE/etc.).
    Execute,
}