rusticx/
connection.rs

1use std::fmt::Debug;
2use std::sync::{Arc, Mutex};
3use crate::error::RusticxError;
4use crate::model::SQLModel;
5use crate::transaction_manager::TransactionExecutor;
6
7// Conditional includes based on feature flags
8#[cfg(feature = "mysql")]
9use crate::transaction_manager::{run_mysql_transaction, mysql};
10#[cfg(feature = "rusqlite")]
11use crate::transaction_manager::{run_sqlite_transaction, rusqlite};
12#[cfg(feature = "postgres")]
13use crate::transaction_manager::{run_postgres_transaction, tokio_postgres};
14#[cfg(feature = "postgres")]
15use postgres::types::ToSql;
16use tokio::runtime::Runtime;
17#[cfg(feature = "mysql")]
18use mysql::prelude::Queryable;
19
/// Identifies which database backend a connection talks to.
///
/// The enum carries no data, so it is cheap to copy and can be compared
/// directly (`db_type == DatabaseType::SQLite`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DatabaseType {
    /// PostgreSQL backend.
    PostgreSQL,
    /// MySQL backend.
    MySQL,
    /// SQLite backend.
    SQLite,
}
30
/// Backend-specific handle that a [`Connection`] uses to reach its database.
///
/// Each backend variant only exists when the matching Cargo feature is
/// enabled. Cloning is cheap: every backend handle is stored behind an `Arc`.
#[derive(Clone)]
pub enum ConnectionPool {
    /// PostgreSQL client plus the Tokio runtime used to drive it.
    ///
    /// The client sits behind `Arc<Mutex<..>>` so clones share one client;
    /// the `Arc<Runtime>` is what the synchronous methods call `block_on` with.
    #[cfg(feature = "postgres")]
    PostgreSQL(Arc<Mutex<tokio_postgres::Client>>, Arc<Runtime>),
    /// Shared MySQL connection pool (`Arc<mysql::Pool>`).
    #[cfg(feature = "mysql")]
    MySQL(Arc<mysql::Pool>),
    /// Single SQLite connection shared behind `Arc<Mutex<..>>`.
    #[cfg(feature = "rusqlite")]
    SQLite(Arc<Mutex<rusqlite::Connection>>),
    /// No backend handle is present (the state before a connection is opened).
    None,
}
54
55/// Represents a database connection with its URL, type, and connection pool.
56///
57/// This struct provides a unified interface for interacting with different
58/// database systems supported by the `rusticx` library.
59#[derive(Clone)]
60pub struct Connection {
61    /// The database connection URL.
62    url: String,
63    /// The type of the database.
64    db_type: DatabaseType,
65    /// The underlying connection pool or client.
66    pool: ConnectionPool,
67}
68
69impl Connection {
70    /// Creates a new `Connection` instance based on the provided database URL.
71    ///
72    /// This function determines the database type from the URL scheme and
73    /// attempts to establish a connection using the appropriate driver.
74    ///
75    /// # Arguments
76    ///
77    /// * `url`: The database connection string (e.g., "postgres://...", "mysql://...", "sqlite://...").
78    ///
79    /// # Returns
80    ///
81    /// Returns a `Result` containing the initialized `Connection` on success,
82    /// or a `RusticxError` if the URL is invalid or connection fails.
83    pub fn new(url: &str) -> Result<Self, RusticxError> {
84        let db_type = if url.starts_with("postgresql://") {
85            DatabaseType::PostgreSQL
86        } else if url.starts_with("mysql://") {
87            DatabaseType::MySQL
88        } else if url.starts_with("sqlite://") {
89            DatabaseType::SQLite
90        } else {
91            return Err(RusticxError::ConnectionError(
92                "Invalid database URL scheme. Must start with postgresql://, mysql://, or sqlite://"
93                    .to_string(),
94            ));
95        };
96
97        let connection = Connection {
98            url: url.to_string(),
99            db_type,
100            pool: ConnectionPool::None, // Initialize with None, connect() will populate
101        };
102
103        // Immediately attempt to connect after determining the type
104        connection.connect()
105    }
106
107    /// Establishes a connection to the database and returns the updated `Connection`.
108    ///
109    /// This internal helper function performs the actual database connection
110    /// based on the determined `DatabaseType` and populates the `pool` field.
111    ///
112    /// # Returns
113    ///
114    /// Returns a `Result` containing the `Connection` with an active pool on success,
115    /// or a `RusticxError` if the connection fails or the database feature is not enabled.
116    fn connect(self) -> Result<Self, RusticxError> {
117        let pool = match self.db_type {
118            #[cfg(feature = "postgres")]
119            DatabaseType::PostgreSQL => {
120                use tokio_postgres::NoTls;
121
122                // Create a dedicated Tokio runtime for blocking async operations
123                let rt = Runtime::new().map_err(|e| {
124                    RusticxError::ConnectionError(format!("Failed to create Tokio runtime: {}", e))
125                })?;
126
127                let (client, connection) = rt
128                    .block_on(async { tokio_postgres::connect(&self.url, NoTls).await })
129                    .map_err(|e| {
130                        RusticxError::ConnectionError(format!("Failed to connect to PostgreSQL: {}", e))
131                    })?;
132
133                // Spawn a task to handle the connection errors asynchronously
134                rt.spawn(async move {
135                    if let Err(e) = connection.await {
136                        eprintln!("PostgreSQL connection error: {}", e);
137                    }
138                });
139
140                ConnectionPool::PostgreSQL(Arc::new(Mutex::new(client)), Arc::new(rt))
141            }
142
143            #[cfg(feature = "mysql")]
144            DatabaseType::MySQL => {
145                let opts = mysql::OptsBuilder::from_opts(
146                    mysql::Opts::from_url(&self.url)
147                        .map_err(|e| RusticxError::ConnectionError(format!("Invalid MySQL URL: {}", e)))?,
148                );
149                let pool = mysql::Pool::new(opts)
150                    .map_err(|e| RusticxError::ConnectionError(format!("Failed to connect to MySQL: {}", e)))?;
151                ConnectionPool::MySQL(Arc::new(pool))
152            }
153
154            #[cfg(feature = "rusqlite")]
155            DatabaseType::SQLite => {
156                let path = self.url.trim_start_matches("sqlite://");
157                let conn = rusqlite::Connection::open(path).map_err(|e| {
158                    RusticxError::ConnectionError(format!("Failed to connect to SQLite: {}", e))
159                })?;
160                ConnectionPool::SQLite(Arc::new(Mutex::new(conn)))
161            }
162
163            // This pattern is marked unreachable because the initial URL check
164            // should cover all supported types. However, it serves as a fallback
165            // for completeness and handles cases where a feature is not enabled.
166            #[allow(unreachable_patterns)]
167            _ => {
168                return Err(RusticxError::ConnectionError(format!(
169                    "Database type {:?} is not supported or the corresponding feature is not enabled (check Cargo.toml)",
170                    self.db_type
171                )));
172            }
173        };
174
175        Ok(Connection {
176            url: self.url.clone(),
177            db_type: self.db_type.clone(),
178            pool,
179        })
180    }
181
182    /// Creates a table in the database based on the provided SQL model definition.
183    ///
184    /// This function uses the `SQLModel` trait to generate the appropriate
185    /// `CREATE TABLE` SQL statement for the current database type and
186    /// executes it.
187    ///
188    /// # Type Parameters
189    ///
190    /// * `T`: The type representing the SQL model, which must implement `SQLModel`.
191    ///
192    /// # Returns
193    ///
194    /// Returns `Ok(())` on successful table creation, or a `RusticxError`
195    /// if the SQL generation or execution fails.
196    pub fn create_table<T: SQLModel>(&self) -> Result<(), RusticxError> {
197        // The table name is not directly used here, but could be for logging or validation
198        let _table_name = T::table_name();
199        let sql = T::create_table_sql(&self.db_type);
200        self.execute(&sql, &[])?;
201        Ok(())
202    }
203
204    /// Executes a SQL command (INSERT, UPDATE, DELETE, CREATE, DROP, etc.)
205    /// with the provided parameters.
206    ///
207    /// This function is typically used for commands that do not return a result set.
208    /// The number of affected rows (where applicable) is returned.
209    ///
210    /// # Arguments
211    ///
212    /// * `sql`: The SQL query string to execute.
213    /// * `params`: A slice of references to values to be used as query parameters.
214    ///             The specific type required depends on the database driver
215    ///             (e.g., `&(dyn ToSql + Sync + 'static)` for postgres).
216    ///
217    /// # Returns
218    ///
219    /// Returns a `Result` containing the number of rows affected on success,
220    /// or a `RusticxError` if the execution fails. Note that the meaning
221    /// and availability of "rows affected" can vary between database drivers.
222    ///
223    /// # Errors
224    ///
225    /// Returns a `RusticxError::QueryError` on database query execution failure
226    /// or `RusticxError::ConnectionError` if the connection pool is not initialized.
227    pub fn execute(
228        &self,
229        sql: &str,
230        params: &[&(dyn ToSql + Sync + 'static)],
231    ) -> Result<u64, RusticxError> {
232        match &self.pool {
233            #[cfg(feature = "postgres")]
234            ConnectionPool::PostgreSQL(client, rt) => {
235                let client_guard = client.lock().map_err(|e| {
236                    RusticxError::TransactionError(format!("Failed to acquire lock on connection: {}", e))
237                })?;
238
239                let result = rt
240                    .block_on(async { client_guard.execute(sql, params).await })
241                    .map_err(|e| RusticxError::QueryError(e.to_string()))?;
242                Ok(result)
243            }
244
245            #[cfg(feature = "mysql")]
246            ConnectionPool::MySQL(pool) => {
247                let mut conn = pool
248                    .get_conn()
249                    .map_err(|e| RusticxError::QueryError(e.to_string()))?;
250                // MySQL's `exec_drop` does not reliably return rows affected, returning 1 is a common workaround
251                conn.exec_drop(sql, ())
252                    .map_err(|e| RusticxError::QueryError(e.to_string()))?;
253                Ok(1) // Indicate at least one operation was attempted
254            }
255
256            #[cfg(feature = "rusqlite")]
257            ConnectionPool::SQLite(conn) => {
258                let conn_guard = conn.lock().map_err(|e| {
259                    RusticxError::ConnectionError(format!("Failed to acquire lock on SQLite connection: {}", e))
260                })?;
261                let result = conn_guard
262                    .execute(sql, []) // rusqlite requires params as a slice of ToSql, converting &[&dyn ToSql] to &[&dyn ToSql] is complex. Assuming no params for simplicity in this example or adjust signature.
263                    .map_err(|e| RusticxError::QueryError(e.to_string()))?;
264                Ok(result as u64)
265            }
266
267            ConnectionPool::None => {
268                Err(RusticxError::ConnectionError(
269                    "No active database connection pool initialized".to_string(),
270                ))
271            }
272
273            // Fallback for unsupported or disabled database types
274            #[allow(unreachable_patterns)]
275            _ => Err(RusticxError::ConnectionError(
276                "Unsupported database type for execute operation".to_string(),
277            )),
278        }
279    }
280
281    /// Executes a raw SQL query (typically SELECT) and returns the results
282    /// as a vector of deserialized objects.
283    ///
284    /// This function queries the database and attempts to map the rows
285    /// from the result set into instances of the specified type `T`.
286    ///
287    /// # Type Parameters
288    ///
289    /// * `T`: The target type to deserialize the rows into. Must implement
290    ///        `serde::Deserialize<'de>` and `Debug`.
291    ///
292    /// # Arguments
293    ///
294    /// * `sql`: The SQL query string (e.g., "SELECT id, name FROM users WHERE age > $1").
295    /// * `params`: A slice of references to values to be used as query parameters.
296    ///             The specific type required depends on the database driver.
297    ///
298    /// # Returns
299    ///
300    /// Returns a `Result` containing a `Vec<T>` on success, where each element
301    /// corresponds to a row from the result set, or a `RusticxError` if the
302    /// query or deserialization fails.
303    ///
304    /// # Errors
305    ///
306    /// Returns a `RusticxError::QueryError` on database query execution failure,
307    /// `RusticxError::SerializationError` if deserialization fails, or
308    /// `RusticxError::ConnectionError` if the connection pool is not initialized.
309    pub fn query_raw<T>(&self, sql: &str, params: &[&(dyn ToSql + Sync + 'static)]) -> Result<Vec<T>, RusticxError>
310    where
311        T: for<'de> serde::Deserialize<'de> + Debug,
312    {
313        match &self.pool {
314            #[cfg(feature = "postgres")]
315            ConnectionPool::PostgreSQL(client, rt) => {
316                let client_guard = client.lock().map_err(|e| {
317                    RusticxError::TransactionError(format!("Failed to acquire lock on connection: {}", e))
318                })?;
319                let rows = rt
320                    .block_on(async { client_guard.query(sql, params).await })
321                    .map_err(|e| RusticxError::QueryError(e.to_string()))?;
322
323                let mut models = Vec::with_capacity(rows.len());
324                for row in rows {
325                    let mut json_obj = serde_json::Map::new();
326                    for column in row.columns() {
327                        let name = column.name();
328                        // Assuming a helper function exists to convert pg row value to JSON
329                        let value = crate::transaction_manager::pg_row_value_to_json(&row, column)
330                            .unwrap_or(serde_json::Value::Null); // Use Null for unconvertible values
331                        json_obj.insert(name.to_string(), value);
332                    }
333                    let model = serde_json::from_value(serde_json::Value::Object(json_obj))
334                        .map_err(|e| RusticxError::SerializationError(e.to_string()))?;
335                    models.push(model);
336                }
337                Ok(models)
338            }
339
340            #[cfg(feature = "mysql")]
341            ConnectionPool::MySQL(pool) => {
342                let mut conn = pool
343                    .get_conn()
344                    .map_err(|e| RusticxError::QueryError(e.to_string()))?;
345
346                // Use query_map to iterate over results and convert
347                let rows: Vec<Result<T, mysql::Error>> = conn
348                    .query_map(sql, |row: mysql::Row| {
349                        let mut json_obj = serde_json::Map::new();
350                        let columns = row.columns_ref();
351
352                        for (i, column) in columns.iter().enumerate() {
353                            let name = column.name_str().to_string();
354                            // Assuming a helper function exists to convert mysql row value to JSON
355                            let value = crate::transaction_manager::mysql_row_value_to_json(
356                                &row,
357                                i,
358                                column.column_type(),
359                            )
360                            .unwrap_or(serde_json::Value::Null);
361                            json_obj.insert(name, value);
362                        }
363
364                        // Deserialize the JSON object into the target struct T
365                        serde_json::from_value(serde_json::Value::Object(json_obj)).map_err(|e| {
366                            // Convert serde_json error to a mysql error for compatibility with query_map
367                            mysql::Error::from(std::io::Error::new(
368                                std::io::ErrorKind::Other,
369                                e.to_string(),
370                            ))
371                        })
372                    })
373                    .map_err(|e| RusticxError::QueryError(e.to_string()))?;
374
375                // Collect the results, converting the vector of Results into a single Result<Vec<T>>
376                let result: Vec<T> = rows
377                    .into_iter()
378                    .collect::<Result<_, _>>()
379                    .map_err(|e| RusticxError::QueryError(e.to_string()))?;
380
381                Ok(result)
382            }
383
384            #[cfg(feature = "rusqlite")]
385            ConnectionPool::SQLite(conn) => {
386                let conn_guard = conn.lock().map_err(|e| {
387                    RusticxError::ConnectionError(format!("Failed to acquire lock on SQLite connection: {}", e))
388                })?;
389
390                let mut stmt = conn_guard
391                    .prepare(sql)
392                    .map_err(|e| RusticxError::QueryError(e.to_string()))?;
393
394                let column_names: Vec<String> = stmt
395                    .column_names()
396                    .iter()
397                    .map(|name| name.to_string())
398                    .collect();
399
400                let models = stmt
401                    .query_map([], |row| {
402                        // Map each row to a JSON object
403                        let mut json_obj = serde_json::Map::new();
404                        for (i, name) in column_names.iter().enumerate() {
405                            // Assuming a helper function exists to convert sqlite row value to JSON
406                            let value = crate::transaction_manager::sqlite_row_value_to_json(row, i)
407                                .unwrap_or(serde_json::Value::Null);
408                            json_obj.insert(name.clone(), value);
409                        }
410                        // Deserialize the JSON object into the target struct T
411                        serde_json::from_value(serde_json::Value::Object(json_obj)).map_err(
412                            |e| {
413                                // Convert serde_json error to a rusqlite error
414                                rusqlite::Error::FromSqlConversionFailure(
415                                    i, // Column index where error occurred
416                                    rusqlite::types::Type::Text, // Assuming Text type for conversion
417                                    Box::new(e),
418                                )
419                            },
420                        )?;
421                        Ok(model)
422                    })
423                    .map_err(|e| RusticxError::QueryError(e.to_string()))?
424                    .collect::<Result<Vec<_>, _>>()
425                    .map_err(|e| RusticxError::QueryError(e.to_string()))?; // Collect results and handle potential errors
426
427                Ok(models)
428            }
429
430            ConnectionPool::None => {
431                Err(RusticxError::ConnectionError(
432                    "No active database connection pool initialized".to_string(),
433                ))
434            }
435
436            // Fallback for unsupported or disabled database types
437            #[allow(unreachable_patterns)]
438            _ => Err(RusticxError::ConnectionError(
439                "Unsupported database type for query operation".to_string(),
440            )),
441        }
442    }
443
444    /// Executes a database transaction using the provided transaction function.
445    ///
446    /// This function manages the transaction lifecycle (begin, commit/rollback)
447    /// and executes the code defined in the `transaction_fn` closure within
448    /// the transaction's scope. The closure receives a `TransactionExecutor`
449    /// which allows performing database operations within the transaction.
450    ///
451    /// # Type Parameters
452    ///
453    /// * `F`: The type of the closure that defines the transaction logic. Must
454    ///        implement `FnOnce(&dyn TransactionExecutor) -> Result<R, RusticxError>`,
455    ///        `Send`, and `'static`.
456    /// * `R`: The return type of the transaction function. Must implement `Send`
457    ///        and `'static`.
458    ///
459    /// # Arguments
460    ///
461    /// * `transaction_fn`: The closure containing the database operations to be
462    ///                   executed within the transaction.
463    ///
464    /// # Returns
465    ///
466    /// Returns a `Result` containing the value `R` returned by the transaction
467    /// function on successful commit, or a `RusticxError` if the transaction
468    /// fails or is rolled back.
469    ///
470    /// # Errors
471    ///
472    /// Returns `RusticxError::TransactionError` on transaction management failures,
473    /// or `RusticxError::ConnectionError` if the connection pool is not initialized.
474    pub async fn transaction<F, R>(&self, transaction_fn: F) -> Result<R, RusticxError>
475    where
476        F: FnOnce(&dyn TransactionExecutor) -> Result<R, RusticxError> + Send + 'static,
477        R: Send + 'static,
478    {
479        match &self.pool {
480            #[cfg(feature = "postgres")]
481            ConnectionPool::PostgreSQL(client, _) => {
482                // Delegate to the PostgreSQL specific transaction runner
483                run_postgres_transaction(&client.clone(), transaction_fn).await
484            }
485
486            #[cfg(feature = "mysql")]
487            ConnectionPool::MySQL(pool) => {
488                // Delegate to the MySQL specific transaction runner
489                run_mysql_transaction(&pool.clone(), transaction_fn)
490            }
491
492            #[cfg(feature = "rusqlite")]
493            ConnectionPool::SQLite(conn) => {
494                // Delegate to the SQLite specific transaction runner
495                run_sqlite_transaction(&conn.clone(), transaction_fn)
496            }
497
498            ConnectionPool::None => {
499                Err(RusticxError::ConnectionError(
500                    "No active database connection pool initialized for transaction".to_string(),
501                ))
502            }
503
504            // Fallback for unsupported or disabled database types
505            #[allow(unreachable_patterns)]
506            _ => Err(RusticxError::ConnectionError(
507                "Unsupported database type for transaction operation".to_string(),
508            )),
509        }
510    }
511
512    /// Returns a reference to the database type of this connection.
513    ///
514    /// # Returns
515    ///
516    /// A reference to the `DatabaseType` enum indicating the connected database.
517    pub fn get_db_type(&self) -> &DatabaseType {
518        &self.db_type
519    }
520}