rusticx/connection.rs
1use std::fmt::Debug;
2use std::sync::{Arc, Mutex};
3use crate::error::RusticxError;
4use crate::model::SQLModel;
5use crate::transaction_manager::TransactionExecutor;
6
7// Conditional includes based on feature flags
8#[cfg(feature = "mysql")]
9use crate::transaction_manager::{run_mysql_transaction, mysql};
10#[cfg(feature = "rusqlite")]
11use crate::transaction_manager::{run_sqlite_transaction, rusqlite};
12#[cfg(feature = "postgres")]
13use crate::transaction_manager::{run_postgres_transaction, tokio_postgres};
14#[cfg(feature = "postgres")]
15use postgres::types::ToSql;
16use tokio::runtime::Runtime;
17#[cfg(feature = "mysql")]
18use mysql::prelude::Queryable;
19
/// Identifies which database backend a `Connection` talks to.
///
/// The variant is chosen from the URL scheme given to `Connection::new` and is
/// passed to `SQLModel::create_table_sql` so models can emit backend-specific DDL.
///
/// The type is `Copy` and comparable, so callers can write
/// `*conn.get_db_type() == DatabaseType::SQLite` or use it as a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DatabaseType {
    /// PostgreSQL database type.
    PostgreSQL,
    /// MySQL database type.
    MySQL,
    /// SQLite database type.
    SQLite,
}
30
/// Backend-specific handle stored inside a `Connection`.
///
/// Each data-carrying variant exists only when its Cargo feature is enabled:
///
/// * `postgres` enables `PostgreSQL`.
/// * `mysql` enables `MySQL`.
/// * `rusqlite` enables `SQLite`.
///
/// `None` is always present. Every payload is wrapped in an `Arc`, so cloning
/// this enum shares the existing client/pool instead of opening a new one.
#[derive(Clone)]
pub enum ConnectionPool {
    /// PostgreSQL client plus the Tokio runtime that drives it.
    ///
    /// The client sits behind a `Mutex` so concurrent users take turns.
    /// The runtime is shared so synchronous code can `block_on` queries.
    #[cfg(feature = "postgres")]
    PostgreSQL(
        std::sync::Arc<std::sync::Mutex<tokio_postgres::Client>>,
        std::sync::Arc<tokio::runtime::Runtime>,
    ),
    /// Shared MySQL connection pool.
    #[cfg(feature = "mysql")]
    MySQL(std::sync::Arc<mysql::Pool>),
    /// Single SQLite connection, guarded by a `Mutex` for shared access.
    #[cfg(feature = "rusqlite")]
    SQLite(std::sync::Arc<std::sync::Mutex<rusqlite::Connection>>),
    /// Placeholder meaning no backend handle has been opened.
    None,
}
54
55/// Represents a database connection with its URL, type, and connection pool.
56///
57/// This struct provides a unified interface for interacting with different
58/// database systems supported by the `rusticx` library.
59#[derive(Clone)]
60pub struct Connection {
61 /// The database connection URL.
62 url: String,
63 /// The type of the database.
64 db_type: DatabaseType,
65 /// The underlying connection pool or client.
66 pool: ConnectionPool,
67}
68
69impl Connection {
70 /// Creates a new `Connection` instance based on the provided database URL.
71 ///
72 /// This function determines the database type from the URL scheme and
73 /// attempts to establish a connection using the appropriate driver.
74 ///
75 /// # Arguments
76 ///
77 /// * `url`: The database connection string (e.g., "postgres://...", "mysql://...", "sqlite://...").
78 ///
79 /// # Returns
80 ///
81 /// Returns a `Result` containing the initialized `Connection` on success,
82 /// or a `RusticxError` if the URL is invalid or connection fails.
83 pub fn new(url: &str) -> Result<Self, RusticxError> {
84 let db_type = if url.starts_with("postgresql://") {
85 DatabaseType::PostgreSQL
86 } else if url.starts_with("mysql://") {
87 DatabaseType::MySQL
88 } else if url.starts_with("sqlite://") {
89 DatabaseType::SQLite
90 } else {
91 return Err(RusticxError::ConnectionError(
92 "Invalid database URL scheme. Must start with postgresql://, mysql://, or sqlite://"
93 .to_string(),
94 ));
95 };
96
97 let connection = Connection {
98 url: url.to_string(),
99 db_type,
100 pool: ConnectionPool::None, // Initialize with None, connect() will populate
101 };
102
103 // Immediately attempt to connect after determining the type
104 connection.connect()
105 }
106
107 /// Establishes a connection to the database and returns the updated `Connection`.
108 ///
109 /// This internal helper function performs the actual database connection
110 /// based on the determined `DatabaseType` and populates the `pool` field.
111 ///
112 /// # Returns
113 ///
114 /// Returns a `Result` containing the `Connection` with an active pool on success,
115 /// or a `RusticxError` if the connection fails or the database feature is not enabled.
116 fn connect(self) -> Result<Self, RusticxError> {
117 let pool = match self.db_type {
118 #[cfg(feature = "postgres")]
119 DatabaseType::PostgreSQL => {
120 use tokio_postgres::NoTls;
121
122 // Create a dedicated Tokio runtime for blocking async operations
123 let rt = Runtime::new().map_err(|e| {
124 RusticxError::ConnectionError(format!("Failed to create Tokio runtime: {}", e))
125 })?;
126
127 let (client, connection) = rt
128 .block_on(async { tokio_postgres::connect(&self.url, NoTls).await })
129 .map_err(|e| {
130 RusticxError::ConnectionError(format!("Failed to connect to PostgreSQL: {}", e))
131 })?;
132
133 // Spawn a task to handle the connection errors asynchronously
134 rt.spawn(async move {
135 if let Err(e) = connection.await {
136 eprintln!("PostgreSQL connection error: {}", e);
137 }
138 });
139
140 ConnectionPool::PostgreSQL(Arc::new(Mutex::new(client)), Arc::new(rt))
141 }
142
143 #[cfg(feature = "mysql")]
144 DatabaseType::MySQL => {
145 let opts = mysql::OptsBuilder::from_opts(
146 mysql::Opts::from_url(&self.url)
147 .map_err(|e| RusticxError::ConnectionError(format!("Invalid MySQL URL: {}", e)))?,
148 );
149 let pool = mysql::Pool::new(opts)
150 .map_err(|e| RusticxError::ConnectionError(format!("Failed to connect to MySQL: {}", e)))?;
151 ConnectionPool::MySQL(Arc::new(pool))
152 }
153
154 #[cfg(feature = "rusqlite")]
155 DatabaseType::SQLite => {
156 let path = self.url.trim_start_matches("sqlite://");
157 let conn = rusqlite::Connection::open(path).map_err(|e| {
158 RusticxError::ConnectionError(format!("Failed to connect to SQLite: {}", e))
159 })?;
160 ConnectionPool::SQLite(Arc::new(Mutex::new(conn)))
161 }
162
163 // This pattern is marked unreachable because the initial URL check
164 // should cover all supported types. However, it serves as a fallback
165 // for completeness and handles cases where a feature is not enabled.
166 #[allow(unreachable_patterns)]
167 _ => {
168 return Err(RusticxError::ConnectionError(format!(
169 "Database type {:?} is not supported or the corresponding feature is not enabled (check Cargo.toml)",
170 self.db_type
171 )));
172 }
173 };
174
175 Ok(Connection {
176 url: self.url.clone(),
177 db_type: self.db_type.clone(),
178 pool,
179 })
180 }
181
182 /// Creates a table in the database based on the provided SQL model definition.
183 ///
184 /// This function uses the `SQLModel` trait to generate the appropriate
185 /// `CREATE TABLE` SQL statement for the current database type and
186 /// executes it.
187 ///
188 /// # Type Parameters
189 ///
190 /// * `T`: The type representing the SQL model, which must implement `SQLModel`.
191 ///
192 /// # Returns
193 ///
194 /// Returns `Ok(())` on successful table creation, or a `RusticxError`
195 /// if the SQL generation or execution fails.
196 pub fn create_table<T: SQLModel>(&self) -> Result<(), RusticxError> {
197 // The table name is not directly used here, but could be for logging or validation
198 let _table_name = T::table_name();
199 let sql = T::create_table_sql(&self.db_type);
200 self.execute(&sql, &[])?;
201 Ok(())
202 }
203
204 /// Executes a SQL command (INSERT, UPDATE, DELETE, CREATE, DROP, etc.)
205 /// with the provided parameters.
206 ///
207 /// This function is typically used for commands that do not return a result set.
208 /// The number of affected rows (where applicable) is returned.
209 ///
210 /// # Arguments
211 ///
212 /// * `sql`: The SQL query string to execute.
213 /// * `params`: A slice of references to values to be used as query parameters.
214 /// The specific type required depends on the database driver
215 /// (e.g., `&(dyn ToSql + Sync + 'static)` for postgres).
216 ///
217 /// # Returns
218 ///
219 /// Returns a `Result` containing the number of rows affected on success,
220 /// or a `RusticxError` if the execution fails. Note that the meaning
221 /// and availability of "rows affected" can vary between database drivers.
222 ///
223 /// # Errors
224 ///
225 /// Returns a `RusticxError::QueryError` on database query execution failure
226 /// or `RusticxError::ConnectionError` if the connection pool is not initialized.
227 pub fn execute(
228 &self,
229 sql: &str,
230 params: &[&(dyn ToSql + Sync + 'static)],
231 ) -> Result<u64, RusticxError> {
232 match &self.pool {
233 #[cfg(feature = "postgres")]
234 ConnectionPool::PostgreSQL(client, rt) => {
235 let client_guard = client.lock().map_err(|e| {
236 RusticxError::TransactionError(format!("Failed to acquire lock on connection: {}", e))
237 })?;
238
239 let result = rt
240 .block_on(async { client_guard.execute(sql, params).await })
241 .map_err(|e| RusticxError::QueryError(e.to_string()))?;
242 Ok(result)
243 }
244
245 #[cfg(feature = "mysql")]
246 ConnectionPool::MySQL(pool) => {
247 let mut conn = pool
248 .get_conn()
249 .map_err(|e| RusticxError::QueryError(e.to_string()))?;
250 // MySQL's `exec_drop` does not reliably return rows affected, returning 1 is a common workaround
251 conn.exec_drop(sql, ())
252 .map_err(|e| RusticxError::QueryError(e.to_string()))?;
253 Ok(1) // Indicate at least one operation was attempted
254 }
255
256 #[cfg(feature = "rusqlite")]
257 ConnectionPool::SQLite(conn) => {
258 let conn_guard = conn.lock().map_err(|e| {
259 RusticxError::ConnectionError(format!("Failed to acquire lock on SQLite connection: {}", e))
260 })?;
261 let result = conn_guard
262 .execute(sql, []) // rusqlite requires params as a slice of ToSql, converting &[&dyn ToSql] to &[&dyn ToSql] is complex. Assuming no params for simplicity in this example or adjust signature.
263 .map_err(|e| RusticxError::QueryError(e.to_string()))?;
264 Ok(result as u64)
265 }
266
267 ConnectionPool::None => {
268 Err(RusticxError::ConnectionError(
269 "No active database connection pool initialized".to_string(),
270 ))
271 }
272
273 // Fallback for unsupported or disabled database types
274 #[allow(unreachable_patterns)]
275 _ => Err(RusticxError::ConnectionError(
276 "Unsupported database type for execute operation".to_string(),
277 )),
278 }
279 }
280
281 /// Executes a raw SQL query (typically SELECT) and returns the results
282 /// as a vector of deserialized objects.
283 ///
284 /// This function queries the database and attempts to map the rows
285 /// from the result set into instances of the specified type `T`.
286 ///
287 /// # Type Parameters
288 ///
289 /// * `T`: The target type to deserialize the rows into. Must implement
290 /// `serde::Deserialize<'de>` and `Debug`.
291 ///
292 /// # Arguments
293 ///
294 /// * `sql`: The SQL query string (e.g., "SELECT id, name FROM users WHERE age > $1").
295 /// * `params`: A slice of references to values to be used as query parameters.
296 /// The specific type required depends on the database driver.
297 ///
298 /// # Returns
299 ///
300 /// Returns a `Result` containing a `Vec<T>` on success, where each element
301 /// corresponds to a row from the result set, or a `RusticxError` if the
302 /// query or deserialization fails.
303 ///
304 /// # Errors
305 ///
306 /// Returns a `RusticxError::QueryError` on database query execution failure,
307 /// `RusticxError::SerializationError` if deserialization fails, or
308 /// `RusticxError::ConnectionError` if the connection pool is not initialized.
309 pub fn query_raw<T>(&self, sql: &str, params: &[&(dyn ToSql + Sync + 'static)]) -> Result<Vec<T>, RusticxError>
310 where
311 T: for<'de> serde::Deserialize<'de> + Debug,
312 {
313 match &self.pool {
314 #[cfg(feature = "postgres")]
315 ConnectionPool::PostgreSQL(client, rt) => {
316 let client_guard = client.lock().map_err(|e| {
317 RusticxError::TransactionError(format!("Failed to acquire lock on connection: {}", e))
318 })?;
319 let rows = rt
320 .block_on(async { client_guard.query(sql, params).await })
321 .map_err(|e| RusticxError::QueryError(e.to_string()))?;
322
323 let mut models = Vec::with_capacity(rows.len());
324 for row in rows {
325 let mut json_obj = serde_json::Map::new();
326 for column in row.columns() {
327 let name = column.name();
328 // Assuming a helper function exists to convert pg row value to JSON
329 let value = crate::transaction_manager::pg_row_value_to_json(&row, column)
330 .unwrap_or(serde_json::Value::Null); // Use Null for unconvertible values
331 json_obj.insert(name.to_string(), value);
332 }
333 let model = serde_json::from_value(serde_json::Value::Object(json_obj))
334 .map_err(|e| RusticxError::SerializationError(e.to_string()))?;
335 models.push(model);
336 }
337 Ok(models)
338 }
339
340 #[cfg(feature = "mysql")]
341 ConnectionPool::MySQL(pool) => {
342 let mut conn = pool
343 .get_conn()
344 .map_err(|e| RusticxError::QueryError(e.to_string()))?;
345
346 // Use query_map to iterate over results and convert
347 let rows: Vec<Result<T, mysql::Error>> = conn
348 .query_map(sql, |row: mysql::Row| {
349 let mut json_obj = serde_json::Map::new();
350 let columns = row.columns_ref();
351
352 for (i, column) in columns.iter().enumerate() {
353 let name = column.name_str().to_string();
354 // Assuming a helper function exists to convert mysql row value to JSON
355 let value = crate::transaction_manager::mysql_row_value_to_json(
356 &row,
357 i,
358 column.column_type(),
359 )
360 .unwrap_or(serde_json::Value::Null);
361 json_obj.insert(name, value);
362 }
363
364 // Deserialize the JSON object into the target struct T
365 serde_json::from_value(serde_json::Value::Object(json_obj)).map_err(|e| {
366 // Convert serde_json error to a mysql error for compatibility with query_map
367 mysql::Error::from(std::io::Error::new(
368 std::io::ErrorKind::Other,
369 e.to_string(),
370 ))
371 })
372 })
373 .map_err(|e| RusticxError::QueryError(e.to_string()))?;
374
375 // Collect the results, converting the vector of Results into a single Result<Vec<T>>
376 let result: Vec<T> = rows
377 .into_iter()
378 .collect::<Result<_, _>>()
379 .map_err(|e| RusticxError::QueryError(e.to_string()))?;
380
381 Ok(result)
382 }
383
384 #[cfg(feature = "rusqlite")]
385 ConnectionPool::SQLite(conn) => {
386 let conn_guard = conn.lock().map_err(|e| {
387 RusticxError::ConnectionError(format!("Failed to acquire lock on SQLite connection: {}", e))
388 })?;
389
390 let mut stmt = conn_guard
391 .prepare(sql)
392 .map_err(|e| RusticxError::QueryError(e.to_string()))?;
393
394 let column_names: Vec<String> = stmt
395 .column_names()
396 .iter()
397 .map(|name| name.to_string())
398 .collect();
399
400 let models = stmt
401 .query_map([], |row| {
402 // Map each row to a JSON object
403 let mut json_obj = serde_json::Map::new();
404 for (i, name) in column_names.iter().enumerate() {
405 // Assuming a helper function exists to convert sqlite row value to JSON
406 let value = crate::transaction_manager::sqlite_row_value_to_json(row, i)
407 .unwrap_or(serde_json::Value::Null);
408 json_obj.insert(name.clone(), value);
409 }
410 // Deserialize the JSON object into the target struct T
411 serde_json::from_value(serde_json::Value::Object(json_obj)).map_err(
412 |e| {
413 // Convert serde_json error to a rusqlite error
414 rusqlite::Error::FromSqlConversionFailure(
415 i, // Column index where error occurred
416 rusqlite::types::Type::Text, // Assuming Text type for conversion
417 Box::new(e),
418 )
419 },
420 )?;
421 Ok(model)
422 })
423 .map_err(|e| RusticxError::QueryError(e.to_string()))?
424 .collect::<Result<Vec<_>, _>>()
425 .map_err(|e| RusticxError::QueryError(e.to_string()))?; // Collect results and handle potential errors
426
427 Ok(models)
428 }
429
430 ConnectionPool::None => {
431 Err(RusticxError::ConnectionError(
432 "No active database connection pool initialized".to_string(),
433 ))
434 }
435
436 // Fallback for unsupported or disabled database types
437 #[allow(unreachable_patterns)]
438 _ => Err(RusticxError::ConnectionError(
439 "Unsupported database type for query operation".to_string(),
440 )),
441 }
442 }
443
444 /// Executes a database transaction using the provided transaction function.
445 ///
446 /// This function manages the transaction lifecycle (begin, commit/rollback)
447 /// and executes the code defined in the `transaction_fn` closure within
448 /// the transaction's scope. The closure receives a `TransactionExecutor`
449 /// which allows performing database operations within the transaction.
450 ///
451 /// # Type Parameters
452 ///
453 /// * `F`: The type of the closure that defines the transaction logic. Must
454 /// implement `FnOnce(&dyn TransactionExecutor) -> Result<R, RusticxError>`,
455 /// `Send`, and `'static`.
456 /// * `R`: The return type of the transaction function. Must implement `Send`
457 /// and `'static`.
458 ///
459 /// # Arguments
460 ///
461 /// * `transaction_fn`: The closure containing the database operations to be
462 /// executed within the transaction.
463 ///
464 /// # Returns
465 ///
466 /// Returns a `Result` containing the value `R` returned by the transaction
467 /// function on successful commit, or a `RusticxError` if the transaction
468 /// fails or is rolled back.
469 ///
470 /// # Errors
471 ///
472 /// Returns `RusticxError::TransactionError` on transaction management failures,
473 /// or `RusticxError::ConnectionError` if the connection pool is not initialized.
474 pub async fn transaction<F, R>(&self, transaction_fn: F) -> Result<R, RusticxError>
475 where
476 F: FnOnce(&dyn TransactionExecutor) -> Result<R, RusticxError> + Send + 'static,
477 R: Send + 'static,
478 {
479 match &self.pool {
480 #[cfg(feature = "postgres")]
481 ConnectionPool::PostgreSQL(client, _) => {
482 // Delegate to the PostgreSQL specific transaction runner
483 run_postgres_transaction(&client.clone(), transaction_fn).await
484 }
485
486 #[cfg(feature = "mysql")]
487 ConnectionPool::MySQL(pool) => {
488 // Delegate to the MySQL specific transaction runner
489 run_mysql_transaction(&pool.clone(), transaction_fn)
490 }
491
492 #[cfg(feature = "rusqlite")]
493 ConnectionPool::SQLite(conn) => {
494 // Delegate to the SQLite specific transaction runner
495 run_sqlite_transaction(&conn.clone(), transaction_fn)
496 }
497
498 ConnectionPool::None => {
499 Err(RusticxError::ConnectionError(
500 "No active database connection pool initialized for transaction".to_string(),
501 ))
502 }
503
504 // Fallback for unsupported or disabled database types
505 #[allow(unreachable_patterns)]
506 _ => Err(RusticxError::ConnectionError(
507 "Unsupported database type for transaction operation".to_string(),
508 )),
509 }
510 }
511
512 /// Returns a reference to the database type of this connection.
513 ///
514 /// # Returns
515 ///
516 /// A reference to the `DatabaseType` enum indicating the connected database.
517 pub fn get_db_type(&self) -> &DatabaseType {
518 &self.db_type
519 }
520}