rustmemodb/connection/
mod.rs1pub mod auth;
2pub mod pool;
3pub mod config;
4
5use crate::core::{DbError, Result};
6use crate::facade::InMemoryDB;
7use crate::result::QueryResult;
8use crate::transaction::TransactionId;
9use std::sync::{Arc};
10use tokio::sync::RwLock;
11use auth::{User};
12
13pub struct Connection {
18 id: u64,
20 user: User,
22 db: Arc<RwLock<InMemoryDB>>,
24 state: ConnectionState,
26 transaction_id: Option<TransactionId>,
28}
29
/// Lifecycle of a [`Connection`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum ConnectionState {
    /// Open and not inside an explicit transaction.
    Active,
    /// Open with an explicit transaction started via `begin`.
    InTransaction,
    /// Closed; statements and new transactions are rejected.
    Closed,
}
36
37impl Connection {
38 pub(crate) fn new(id: u64, user: User, db: Arc<RwLock<InMemoryDB>>) -> Self {
40 Self {
41 id,
42 user,
43 db,
44 state: ConnectionState::Active,
45 transaction_id: None,
46 }
47 }
48
49 pub fn id(&self) -> u64 {
51 self.id
52 }
53
54 pub fn username(&self) -> &str {
56 self.user.username()
57 }
58
59 pub async fn execute(&mut self, sql: &str) -> Result<QueryResult> {
61 if self.state == ConnectionState::Closed {
62 return Err(DbError::ExecutionError("Connection is closed".into()));
63 }
64
65 let trimmed = sql.trim().to_uppercase();
67 if trimmed == "BEGIN" || trimmed == "BEGIN TRANSACTION" || trimmed == "START TRANSACTION" {
68 self.begin().await?;
69 return Ok(QueryResult::empty_with_message("Transaction started".to_string()));
70 }
71 if trimmed == "COMMIT" || trimmed == "COMMIT TRANSACTION" {
72 self.commit().await?;
73 return Ok(QueryResult::empty_with_message("Transaction committed".to_string()));
74 }
75 if trimmed == "ROLLBACK" || trimmed == "ROLLBACK TRANSACTION" {
76 self.rollback().await?;
77 return Ok(QueryResult::empty_with_message("Transaction rolled back".to_string()));
78 }
79
80 let mut db = self.db.write().await;
81 db.execute_with_transaction(sql, self.transaction_id).await
82 }
83
84 pub async fn query(&mut self, sql: &str) -> Result<QueryResult> {
88 self.execute(sql).await
89 }
90
91 pub async fn exec(&mut self, sql: &str) -> Result<u64> {
95 let result = self.execute(sql).await?;
96 Ok(result.row_count() as u64)
97 }
98
99 pub async fn begin(&mut self) -> Result<()> {
101 if self.state == ConnectionState::Closed {
102 return Err(DbError::ExecutionError("Connection is closed".into()));
103 }
104
105 if self.state == ConnectionState::InTransaction {
106 return Err(DbError::ExecutionError("Transaction already active".into()));
107 }
108
109 let txn_id = {
111 let db = self.db.read().await;
112 db.transaction_manager().begin().await?
113 };
114
115 self.state = ConnectionState::InTransaction;
116 self.transaction_id = Some(txn_id);
117
118 Ok(())
119 }
120
121 pub async fn commit(&mut self) -> Result<()> {
123 if self.state != ConnectionState::InTransaction {
124 return Err(DbError::ExecutionError("No active transaction".into()));
125 }
126
127 let txn_id = self.transaction_id.expect("Transaction ID must be set in InTransaction state");
128
129 {
131 let db = self.db.write().await;
132 let txn_mgr = Arc::clone(db.transaction_manager());
133 txn_mgr.commit(txn_id).await?;
134 }
135
136 self.state = ConnectionState::Active;
137 self.transaction_id = None;
138
139 Ok(())
140 }
141
142 pub async fn rollback(&mut self) -> Result<()> {
144 if self.state != ConnectionState::InTransaction {
145 return Ok(());
147 }
148
149 let txn_id = self.transaction_id.expect("Transaction ID must be set in InTransaction state");
150
151 {
153 let mut db = self.db.write().await;
154 let txn_mgr = Arc::clone(db.transaction_manager());
155 let storage = db.storage_mut();
156 txn_mgr.rollback_with_storage(txn_id, storage).await?;
157 }
158
159 self.state = ConnectionState::Active;
160 self.transaction_id = None;
161
162 Ok(())
163 }
164
165 pub fn is_in_transaction(&self) -> bool {
167 self.state == ConnectionState::InTransaction
168 }
169
170 pub fn is_active(&self) -> bool {
172 self.state != ConnectionState::Closed
173 }
174
175 pub async fn close(&mut self) -> Result<()> {
177 if self.state == ConnectionState::InTransaction {
178 self.rollback().await?;
179 }
180
181 self.state = ConnectionState::Closed;
182 Ok(())
183 }
184
185 pub fn prepare(&self, sql: &str) -> Result<PreparedStatement> {
187 if self.state == ConnectionState::Closed {
188 return Err(DbError::ExecutionError("Connection is closed".into()));
189 }
190
191 Ok(PreparedStatement {
192 sql: sql.to_string(),
193 db: Arc::clone(&self.db),
194 })
195 }
196}
197
198impl Drop for Connection {
199 fn drop(&mut self) {
200 if self.state == ConnectionState::InTransaction {
201 eprintln!("Warning: Connection dropped while in transaction. Transaction may hang. Use connection.close() or commit/rollback explicitly.");
204 }
205 self.state = ConnectionState::Closed;
206 }
207}
208
209pub struct PreparedStatement {
213 sql: String,
214 db: Arc<RwLock<InMemoryDB>>,
215}
216
217impl PreparedStatement {
218 pub async fn execute(&self, _params: &[&dyn std::fmt::Display]) -> Result<QueryResult> {
220 let mut db = self.db.write().await;
221 db.execute(&self.sql).await
222 }
223
224 pub fn sql(&self) -> &str {
226 &self.sql
227 }
228}
229
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens connection #1 for `test_user` on a brand-new database.
    async fn fresh_connection() -> Connection {
        let shared_db = Arc::new(RwLock::new(InMemoryDB::new()));
        let owner = User::new("test_user".to_string(), "hash".to_string(), Vec::new());
        Connection::new(1, owner, shared_db)
    }

    #[tokio::test]
    async fn test_connection_creation() {
        let conn = fresh_connection().await;

        assert_eq!(conn.id(), 1);
        assert_eq!(conn.username(), "test_user");
        assert!(conn.is_active(), "a new connection must be active");
        assert!(
            !conn.is_in_transaction(),
            "a new connection must not be in a transaction"
        );
    }

    #[tokio::test]
    async fn test_transaction_lifecycle() {
        let mut conn = fresh_connection().await;

        let began = conn.begin().await;
        assert!(began.is_ok(), "begin failed");
        assert!(conn.is_in_transaction());

        let committed = conn.commit().await;
        assert!(committed.is_ok(), "commit failed");
        assert!(!conn.is_in_transaction());
    }

    #[tokio::test]
    async fn test_transaction_rollback() {
        let mut conn = fresh_connection().await;

        let began = conn.begin().await;
        assert!(began.is_ok(), "begin failed");
        assert!(conn.is_in_transaction());

        let rolled_back = conn.rollback().await;
        assert!(rolled_back.is_ok(), "rollback failed");
        assert!(!conn.is_in_transaction());
    }

    #[tokio::test]
    async fn test_connection_close() {
        let mut conn = fresh_connection().await;

        let closed = conn.close().await;
        assert!(closed.is_ok(), "close failed");
        assert!(!conn.is_active());

        // Statements must be rejected once the connection is closed.
        let after_close = conn.execute("SELECT 1").await;
        assert!(after_close.is_err());
    }
}