1use anyhow::Result;
2use rusqlite::{Connection, OpenFlags};
3use std::path::{Path, PathBuf};
4use std::sync::{Arc, Mutex};
5use std::time::{Duration, Instant};
6use std::collections::VecDeque;
7
8#[derive(Debug)]
10pub struct PooledConnection {
11 pub connection: Connection,
12 created_at: Instant,
13 last_used: Instant,
14 use_count: usize,
15}
16
17impl PooledConnection {
18 fn new(connection: Connection) -> Self {
19 let now = Instant::now();
20 Self {
21 connection,
22 created_at: now,
23 last_used: now,
24 use_count: 0,
25 }
26 }
27
28 fn mark_used(&mut self) {
29 self.last_used = Instant::now();
30 self.use_count += 1;
31 }
32
33 fn is_expired(&self, max_lifetime: Duration) -> bool {
34 self.created_at.elapsed() > max_lifetime
35 }
36
37 fn is_idle_too_long(&self, max_idle: Duration) -> bool {
38 self.last_used.elapsed() > max_idle
39 }
40}
41
/// Tuning knobs for [`DatabasePool`].
#[derive(Debug, Clone)]
pub struct PoolConfig {
    /// Capacity limit: new connections are only opened while the number
    /// checked out plus the number idle in the pool is below this value.
    pub max_connections: usize,
    /// Number of connections opened eagerly when the pool is constructed.
    pub min_connections: usize,
    /// Connections older than this are discarded instead of being reused.
    pub max_lifetime: Duration,
    /// Pooled connections whose last-use timestamp is older than this are discarded.
    pub max_idle_time: Duration,
    /// How long `DatabasePool::get_connection` keeps retrying before it fails.
    pub connection_timeout: Duration,
}

impl Default for PoolConfig {
    /// At most ten connections with two opened up front, a one-hour lifetime,
    /// a ten-minute idle limit and a thirty-second acquire timeout.
    fn default() -> Self {
        const SECS_PER_MINUTE: u64 = 60;
        Self {
            min_connections: 2,
            max_connections: 10,
            connection_timeout: Duration::from_secs(30),
            max_idle_time: Duration::from_secs(10 * SECS_PER_MINUTE),
            max_lifetime: Duration::from_secs(60 * SECS_PER_MINUTE),
        }
    }
}
63
64pub struct DatabasePool {
66 db_path: PathBuf,
67 pool: Arc<Mutex<VecDeque<PooledConnection>>>,
68 config: PoolConfig,
69 stats: Arc<Mutex<PoolStats>>,
70}
71
/// Point-in-time counters describing pool usage.
///
/// Derives `Clone`, `Copy`, `PartialEq` and `Eq` so callers can keep and
/// compare snapshots (e.g. before/after an operation) without copying
/// every field by hand.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct PoolStats {
    /// Connections opened over the pool's lifetime.
    pub total_connections_created: usize,
    /// Connections currently counted as in use by callers.
    pub active_connections: usize,
    /// Idle connections currently counted as held by the pool.
    pub connections_in_pool: usize,
    /// Number of connection requests made to the pool.
    pub connection_requests: usize,
    /// Requests that gave up after the configured connection timeout.
    pub connection_timeouts: usize,
}
80
81impl DatabasePool {
82 pub fn new<P: AsRef<Path>>(db_path: P, config: PoolConfig) -> Result<Self> {
84 let db_path = db_path.as_ref().to_path_buf();
85
86 if let Some(parent) = db_path.parent() {
88 std::fs::create_dir_all(parent)?;
89 }
90
91 let pool = Self {
92 db_path,
93 pool: Arc::new(Mutex::new(VecDeque::new())),
94 config,
95 stats: Arc::new(Mutex::new(PoolStats::default())),
96 };
97
98 pool.ensure_min_connections()?;
100
101 Ok(pool)
102 }
103
104 pub fn new_with_defaults<P: AsRef<Path>>(db_path: P) -> Result<Self> {
106 Self::new(db_path, PoolConfig::default())
107 }
108
109 pub async fn get_connection(&self) -> Result<PooledConnectionGuard> {
111 let start = Instant::now();
112
113 {
115 let mut stats = self.stats.lock()
116 .map_err(|e| anyhow::anyhow!("Failed to acquire stats lock: {}", e))?;
117 stats.connection_requests += 1;
118 }
119
120 loop {
121 if let Some(mut conn) = self.try_get_from_pool()? {
123 conn.mark_used();
124
125 {
127 let mut stats = self.stats.lock()
128 .map_err(|e| anyhow::anyhow!("Failed to acquire stats lock: {}", e))?;
129 stats.active_connections += 1;
130 stats.connections_in_pool = stats.connections_in_pool.saturating_sub(1);
131 }
132
133 return Ok(PooledConnectionGuard::new(conn, self.pool.clone(), self.stats.clone()));
134 }
135
136 if self.can_create_new_connection()? {
138 let conn = self.create_connection()?;
139
140 {
142 let mut stats = self.stats.lock()
143 .map_err(|e| anyhow::anyhow!("Failed to acquire stats lock: {}", e))?;
144 stats.total_connections_created += 1;
145 stats.active_connections += 1;
146 }
147
148 return Ok(PooledConnectionGuard::new(conn, self.pool.clone(), self.stats.clone()));
149 }
150
151 if start.elapsed() > self.config.connection_timeout {
153 let mut stats = self.stats.lock()
154 .map_err(|e| anyhow::anyhow!("Failed to acquire stats lock: {}", e))?;
155 stats.connection_timeouts += 1;
156 return Err(anyhow::anyhow!("Connection timeout after {:?}", self.config.connection_timeout));
157 }
158
159 tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
161 }
162 }
163
164 fn try_get_from_pool(&self) -> Result<Option<PooledConnection>> {
166 let mut pool = self.pool.lock()
167 .map_err(|e| anyhow::anyhow!("Failed to acquire pool lock: {}", e))?;
168
169 self.cleanup_connections(&mut pool)?;
171
172 Ok(pool.pop_front())
174 }
175
176 fn can_create_new_connection(&self) -> Result<bool> {
178 let stats = self.stats.lock()
179 .map_err(|e| anyhow::anyhow!("Failed to acquire stats lock: {}", e))?;
180 Ok(stats.active_connections + stats.connections_in_pool < self.config.max_connections)
181 }
182
183 fn create_connection(&self) -> Result<PooledConnection> {
185 let connection = Connection::open_with_flags(
186 &self.db_path,
187 OpenFlags::SQLITE_OPEN_READ_WRITE
188 | OpenFlags::SQLITE_OPEN_CREATE
189 | OpenFlags::SQLITE_OPEN_NO_MUTEX,
190 )?;
191
192 connection.pragma_update(None, "foreign_keys", "ON")?;
194 connection.pragma_update(None, "journal_mode", "WAL")?;
195 connection.pragma_update(None, "synchronous", "NORMAL")?;
196 connection.pragma_update(None, "cache_size", "-64000")?;
197
198 crate::db::migrations::run_migrations(&connection)?;
200
201 Ok(PooledConnection::new(connection))
202 }
203
204 fn cleanup_connections(&self, pool: &mut VecDeque<PooledConnection>) -> Result<()> {
206 let mut to_remove = Vec::new();
207
208 for (index, conn) in pool.iter().enumerate() {
209 if conn.is_expired(self.config.max_lifetime) ||
210 conn.is_idle_too_long(self.config.max_idle_time) {
211 to_remove.push(index);
212 }
213 }
214
215 for index in to_remove.iter().rev() {
217 pool.remove(*index);
218 }
219
220 Ok(())
221 }
222
223 fn ensure_min_connections(&self) -> Result<()> {
225 let mut pool = self.pool.lock()
226 .map_err(|e| anyhow::anyhow!("Failed to acquire pool lock: {}", e))?;
227
228 while pool.len() < self.config.min_connections {
229 let conn = self.create_connection()?;
230 pool.push_back(conn);
231
232 let mut stats = self.stats.lock()
234 .map_err(|e| anyhow::anyhow!("Failed to acquire stats lock: {}", e))?;
235 stats.total_connections_created += 1;
236 stats.connections_in_pool += 1;
237 }
238
239 Ok(())
240 }
241
242 fn return_connection(&self, conn: PooledConnection) -> Result<()> {
244 let mut pool = self.pool.lock()
245 .map_err(|e| anyhow::anyhow!("Failed to acquire pool lock: {}", e))?;
246
247 if !conn.is_expired(self.config.max_lifetime) &&
249 pool.len() < self.config.max_connections {
250 pool.push_back(conn);
251
252 let mut stats = self.stats.lock()
254 .map_err(|e| anyhow::anyhow!("Failed to acquire stats lock: {}", e))?;
255 stats.connections_in_pool += 1;
256 stats.active_connections = stats.active_connections.saturating_sub(1);
257 } else {
258 let mut stats = self.stats.lock()
260 .map_err(|e| anyhow::anyhow!("Failed to acquire stats lock: {}", e))?;
261 stats.active_connections = stats.active_connections.saturating_sub(1);
262 }
263
264 Ok(())
265 }
266
267 pub fn stats(&self) -> Result<PoolStats> {
269 let stats = self.stats.lock()
270 .map_err(|e| anyhow::anyhow!("Failed to acquire stats lock: {}", e))?;
271 Ok(PoolStats {
272 total_connections_created: stats.total_connections_created,
273 active_connections: stats.active_connections,
274 connections_in_pool: stats.connections_in_pool,
275 connection_requests: stats.connection_requests,
276 connection_timeouts: stats.connection_timeouts,
277 })
278 }
279
280 pub fn close(&self) -> Result<()> {
282 let mut pool = self.pool.lock()
283 .map_err(|e| anyhow::anyhow!("Failed to acquire pool lock: {}", e))?;
284 pool.clear();
285
286 let mut stats = self.stats.lock()
287 .map_err(|e| anyhow::anyhow!("Failed to acquire stats lock: {}", e))?;
288 stats.connections_in_pool = 0;
289
290 Ok(())
291 }
292}
293
294pub struct PooledConnectionGuard {
296 connection: Option<PooledConnection>,
297 pool: Arc<Mutex<VecDeque<PooledConnection>>>,
298 stats: Arc<Mutex<PoolStats>>,
299}
300
301impl PooledConnectionGuard {
302 fn new(
303 connection: PooledConnection,
304 pool: Arc<Mutex<VecDeque<PooledConnection>>>,
305 stats: Arc<Mutex<PoolStats>>,
306 ) -> Self {
307 Self {
308 connection: Some(connection),
309 pool,
310 stats,
311 }
312 }
313
314 pub fn connection(&self) -> &Connection {
316 &self.connection.as_ref().unwrap().connection
317 }
318}
319
320impl Drop for PooledConnectionGuard {
321 fn drop(&mut self) {
322 if let Some(conn) = self.connection.take() {
323 let mut pool = match self.pool.lock() {
325 Ok(pool) => pool,
326 Err(_) => {
327 if let Ok(mut stats) = self.stats.lock() {
329 stats.active_connections = stats.active_connections.saturating_sub(1);
330 }
331 return;
332 }
333 };
334
335 if !conn.is_expired(Duration::from_secs(3600)) && pool.len() < 10 {
337 pool.push_back(conn);
338 if let Ok(mut stats) = self.stats.lock() {
339 stats.connections_in_pool += 1;
340 stats.active_connections = stats.active_connections.saturating_sub(1);
341 }
342 } else {
343 if let Ok(mut stats) = self.stats.lock() {
345 stats.active_connections = stats.active_connections.saturating_sub(1);
346 }
347 }
348 }
349 }
350}
351
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Opens a pool with the default configuration on `test.db` inside `dir`.
    fn open_pool(dir: &std::path::Path) -> DatabasePool {
        DatabasePool::new_with_defaults(dir.join("test.db")).unwrap()
    }

    #[test]
    fn test_pool_creation() {
        let dir = tempdir().unwrap();
        let pool = open_pool(dir.path());

        let snapshot = pool.stats().unwrap();
        // Construction eagerly opens `min_connections` (2 by default).
        assert!(snapshot.total_connections_created >= 2);
        assert_eq!(snapshot.connections_in_pool, 2);
    }

    #[tokio::test]
    async fn test_get_connection() {
        let dir = tempdir().unwrap();
        let pool = open_pool(dir.path());

        let guard = pool.get_connection().await.unwrap();
        guard
            .connection()
            .execute("CREATE TABLE test (id INTEGER)", [])
            .unwrap();

        // The guard is still alive, so its connection counts as active.
        assert_eq!(pool.stats().unwrap().active_connections, 1);
    }

    #[tokio::test]
    async fn test_connection_return() {
        let dir = tempdir().unwrap();
        let pool = open_pool(dir.path());

        let guard = pool.get_connection().await.unwrap();
        assert_eq!(pool.stats().unwrap().active_connections, 1);
        drop(guard);

        // Dropping the guard hands the connection back to the pool.
        let after = pool.stats().unwrap();
        assert_eq!(after.active_connections, 0);
        assert!(after.connections_in_pool > 0);
    }
}