tempo_cli/db/
pool.rs

1use anyhow::Result;
2use rusqlite::{Connection, OpenFlags};
3use std::collections::VecDeque;
4use std::path::{Path, PathBuf};
5use std::sync::{Arc, Mutex};
6use std::time::{Duration, Instant};
7
8/// A database connection with metadata
9#[derive(Debug)]
10pub struct PooledConnection {
11    pub connection: Connection,
12    created_at: Instant,
13    last_used: Instant,
14    use_count: usize,
15}
16
17impl PooledConnection {
18    fn new(connection: Connection) -> Self {
19        let now = Instant::now();
20        Self {
21            connection,
22            created_at: now,
23            last_used: now,
24            use_count: 0,
25        }
26    }
27
28    fn mark_used(&mut self) {
29        self.last_used = Instant::now();
30        self.use_count += 1;
31    }
32
33    fn is_expired(&self, max_lifetime: Duration) -> bool {
34        self.created_at.elapsed() > max_lifetime
35    }
36
37    fn is_idle_too_long(&self, max_idle: Duration) -> bool {
38        self.last_used.elapsed() > max_idle
39    }
40}
41
/// Tuning knobs for the SQLite connection pool.
#[derive(Clone, Debug)]
pub struct PoolConfig {
    /// Upper bound on connections open at once (checked out plus idle).
    pub max_connections: usize,
    /// Number of idle connections opened when the pool is created.
    pub min_connections: usize,
    /// Age after which a connection is no longer kept in the pool.
    pub max_lifetime: Duration,
    /// How long a connection may sit unused in the pool before it is discarded.
    pub max_idle_time: Duration,
    /// How long `get_connection` keeps retrying before it reports a timeout.
    pub connection_timeout: Duration,
}
51
52impl Default for PoolConfig {
53    fn default() -> Self {
54        Self {
55            max_connections: 10,
56            min_connections: 2,
57            max_lifetime: Duration::from_secs(3600), // 1 hour
58            max_idle_time: Duration::from_secs(600), // 10 minutes
59            connection_timeout: Duration::from_secs(30),
60        }
61    }
62}
63
64/// A connection pool for SQLite databases
65pub struct DatabasePool {
66    db_path: PathBuf,
67    pool: Arc<Mutex<VecDeque<PooledConnection>>>,
68    config: PoolConfig,
69    stats: Arc<Mutex<PoolStats>>,
70}
71
/// Counters describing the pool's activity and current occupancy.
///
/// `DatabasePool::stats` returns a point-in-time copy of these values.
/// `Clone` and `PartialEq` let callers keep and compare snapshots.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct PoolStats {
    /// Connections opened over the pool's lifetime.
    pub total_connections_created: usize,
    /// Connections currently checked out by callers.
    pub active_connections: usize,
    /// Idle connections currently held in the pool.
    pub connections_in_pool: usize,
    /// Number of `get_connection` calls.
    pub connection_requests: usize,
    /// `get_connection` calls that gave up after `connection_timeout`.
    pub connection_timeouts: usize,
}
80
81impl DatabasePool {
82    /// Create a new database pool
83    pub fn new<P: AsRef<Path>>(db_path: P, config: PoolConfig) -> Result<Self> {
84        let db_path = db_path.as_ref().to_path_buf();
85
86        // Create parent directory if it doesn't exist
87        if let Some(parent) = db_path.parent() {
88            std::fs::create_dir_all(parent)?;
89        }
90
91        let pool = Self {
92            db_path,
93            pool: Arc::new(Mutex::new(VecDeque::new())),
94            config,
95            stats: Arc::new(Mutex::new(PoolStats::default())),
96        };
97
98        // Pre-populate with minimum connections
99        pool.ensure_min_connections()?;
100
101        Ok(pool)
102    }
103
104    /// Create a new database pool with default configuration
105    pub fn new_with_defaults<P: AsRef<Path>>(db_path: P) -> Result<Self> {
106        Self::new(db_path, PoolConfig::default())
107    }
108
109    /// Get a connection from the pool
110    pub async fn get_connection(&self) -> Result<PooledConnectionGuard> {
111        let start = Instant::now();
112
113        // Update stats
114        {
115            let mut stats = self
116                .stats
117                .lock()
118                .map_err(|e| anyhow::anyhow!("Failed to acquire stats lock: {}", e))?;
119            stats.connection_requests += 1;
120        }
121
122        loop {
123            // Try to get a connection from the pool
124            if let Some(mut conn) = self.try_get_from_pool()? {
125                conn.mark_used();
126
127                // Update stats
128                {
129                    let mut stats = self
130                        .stats
131                        .lock()
132                        .map_err(|e| anyhow::anyhow!("Failed to acquire stats lock: {}", e))?;
133                    stats.active_connections += 1;
134                    stats.connections_in_pool = stats.connections_in_pool.saturating_sub(1);
135                }
136
137                return Ok(PooledConnectionGuard::new(
138                    conn,
139                    self.pool.clone(),
140                    self.stats.clone(),
141                ));
142            }
143
144            // If no connection available, try to create a new one
145            if self.can_create_new_connection()? {
146                let conn = self.create_connection()?;
147
148                // Update stats
149                {
150                    let mut stats = self
151                        .stats
152                        .lock()
153                        .map_err(|e| anyhow::anyhow!("Failed to acquire stats lock: {}", e))?;
154                    stats.total_connections_created += 1;
155                    stats.active_connections += 1;
156                }
157
158                return Ok(PooledConnectionGuard::new(
159                    conn,
160                    self.pool.clone(),
161                    self.stats.clone(),
162                ));
163            }
164
165            // Check for timeout
166            if start.elapsed() > self.config.connection_timeout {
167                let mut stats = self
168                    .stats
169                    .lock()
170                    .map_err(|e| anyhow::anyhow!("Failed to acquire stats lock: {}", e))?;
171                stats.connection_timeouts += 1;
172                return Err(anyhow::anyhow!(
173                    "Connection timeout after {:?}",
174                    self.config.connection_timeout
175                ));
176            }
177
178            // Wait a bit before retrying
179            tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
180        }
181    }
182
183    /// Try to get a connection from the existing pool
184    fn try_get_from_pool(&self) -> Result<Option<PooledConnection>> {
185        let mut pool = self
186            .pool
187            .lock()
188            .map_err(|e| anyhow::anyhow!("Failed to acquire pool lock: {}", e))?;
189
190        // Clean up expired/idle connections first
191        self.cleanup_connections(&mut pool)?;
192
193        // Try to get a connection
194        Ok(pool.pop_front())
195    }
196
197    /// Check if we can create a new connection
198    fn can_create_new_connection(&self) -> Result<bool> {
199        let stats = self
200            .stats
201            .lock()
202            .map_err(|e| anyhow::anyhow!("Failed to acquire stats lock: {}", e))?;
203        Ok(stats.active_connections + stats.connections_in_pool < self.config.max_connections)
204    }
205
206    /// Create a new database connection
207    fn create_connection(&self) -> Result<PooledConnection> {
208        let connection = Connection::open_with_flags(
209            &self.db_path,
210            OpenFlags::SQLITE_OPEN_READ_WRITE
211                | OpenFlags::SQLITE_OPEN_CREATE
212                | OpenFlags::SQLITE_OPEN_NO_MUTEX,
213        )?;
214
215        // Configure the connection
216        connection.pragma_update(None, "foreign_keys", "ON")?;
217        connection.pragma_update(None, "journal_mode", "WAL")?;
218        connection.pragma_update(None, "synchronous", "NORMAL")?;
219        connection.pragma_update(None, "cache_size", "-64000")?;
220
221        // Run migrations
222        crate::db::migrations::run_migrations(&connection)?;
223
224        Ok(PooledConnection::new(connection))
225    }
226
227    /// Clean up expired and idle connections
228    fn cleanup_connections(&self, pool: &mut VecDeque<PooledConnection>) -> Result<()> {
229        let mut to_remove = Vec::new();
230
231        for (index, conn) in pool.iter().enumerate() {
232            if conn.is_expired(self.config.max_lifetime)
233                || conn.is_idle_too_long(self.config.max_idle_time)
234            {
235                to_remove.push(index);
236            }
237        }
238
239        // Remove connections in reverse order to maintain indices
240        for index in to_remove.iter().rev() {
241            pool.remove(*index);
242        }
243
244        Ok(())
245    }
246
247    /// Ensure minimum number of connections are available
248    fn ensure_min_connections(&self) -> Result<()> {
249        let mut pool = self
250            .pool
251            .lock()
252            .map_err(|e| anyhow::anyhow!("Failed to acquire pool lock: {}", e))?;
253
254        while pool.len() < self.config.min_connections {
255            let conn = self.create_connection()?;
256            pool.push_back(conn);
257
258            // Update stats
259            let mut stats = self
260                .stats
261                .lock()
262                .map_err(|e| anyhow::anyhow!("Failed to acquire stats lock: {}", e))?;
263            stats.total_connections_created += 1;
264            stats.connections_in_pool += 1;
265        }
266
267        Ok(())
268    }
269
270    /// Return a connection to the pool
271    #[allow(dead_code)]
272    fn return_connection(&self, conn: PooledConnection) -> Result<()> {
273        let mut pool = self
274            .pool
275            .lock()
276            .map_err(|e| anyhow::anyhow!("Failed to acquire pool lock: {}", e))?;
277
278        // Check if we should keep this connection
279        if !conn.is_expired(self.config.max_lifetime) && pool.len() < self.config.max_connections {
280            pool.push_back(conn);
281
282            // Update stats
283            let mut stats = self
284                .stats
285                .lock()
286                .map_err(|e| anyhow::anyhow!("Failed to acquire stats lock: {}", e))?;
287            stats.connections_in_pool += 1;
288            stats.active_connections = stats.active_connections.saturating_sub(1);
289        } else {
290            // Update stats - connection is being dropped
291            let mut stats = self
292                .stats
293                .lock()
294                .map_err(|e| anyhow::anyhow!("Failed to acquire stats lock: {}", e))?;
295            stats.active_connections = stats.active_connections.saturating_sub(1);
296        }
297
298        Ok(())
299    }
300
301    /// Get current pool statistics
302    pub fn stats(&self) -> Result<PoolStats> {
303        let stats = self
304            .stats
305            .lock()
306            .map_err(|e| anyhow::anyhow!("Failed to acquire stats lock: {}", e))?;
307        Ok(PoolStats {
308            total_connections_created: stats.total_connections_created,
309            active_connections: stats.active_connections,
310            connections_in_pool: stats.connections_in_pool,
311            connection_requests: stats.connection_requests,
312            connection_timeouts: stats.connection_timeouts,
313        })
314    }
315
316    /// Close all connections in the pool
317    pub fn close(&self) -> Result<()> {
318        let mut pool = self
319            .pool
320            .lock()
321            .map_err(|e| anyhow::anyhow!("Failed to acquire pool lock: {}", e))?;
322        pool.clear();
323
324        let mut stats = self
325            .stats
326            .lock()
327            .map_err(|e| anyhow::anyhow!("Failed to acquire stats lock: {}", e))?;
328        stats.connections_in_pool = 0;
329
330        Ok(())
331    }
332}
333
334/// A guard that automatically returns connections to the pool when dropped
335pub struct PooledConnectionGuard {
336    connection: Option<PooledConnection>,
337    pool: Arc<Mutex<VecDeque<PooledConnection>>>,
338    stats: Arc<Mutex<PoolStats>>,
339}
340
341impl PooledConnectionGuard {
342    fn new(
343        connection: PooledConnection,
344        pool: Arc<Mutex<VecDeque<PooledConnection>>>,
345        stats: Arc<Mutex<PoolStats>>,
346    ) -> Self {
347        Self {
348            connection: Some(connection),
349            pool,
350            stats,
351        }
352    }
353
354    /// Get a reference to the underlying connection
355    pub fn connection(&self) -> &Connection {
356        &self.connection.as_ref().unwrap().connection
357    }
358}
359
360impl Drop for PooledConnectionGuard {
361    fn drop(&mut self) {
362        if let Some(conn) = self.connection.take() {
363            // Try to return connection to pool
364            let mut pool = match self.pool.lock() {
365                Ok(pool) => pool,
366                Err(_) => {
367                    // Pool lock is poisoned, just update stats
368                    if let Ok(mut stats) = self.stats.lock() {
369                        stats.active_connections = stats.active_connections.saturating_sub(1);
370                    }
371                    return;
372                }
373            };
374
375            // Check if we should keep this connection
376            if !conn.is_expired(Duration::from_secs(3600)) && pool.len() < 10 {
377                pool.push_back(conn);
378                if let Ok(mut stats) = self.stats.lock() {
379                    stats.connections_in_pool += 1;
380                    stats.active_connections = stats.active_connections.saturating_sub(1);
381                }
382            } else {
383                // Connection is being dropped
384                if let Ok(mut stats) = self.stats.lock() {
385                    stats.active_connections = stats.active_connections.saturating_sub(1);
386                }
387            }
388        }
389    }
390}
391
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Build a default-configured pool backed by a fresh temporary directory.
    ///
    /// The `TempDir` is returned alongside the pool so the directory lives as
    /// long as the test keeps it bound.
    fn fresh_pool() -> (tempfile::TempDir, DatabasePool) {
        let dir = tempdir().unwrap();
        let pool = DatabasePool::new_with_defaults(dir.path().join("test.db")).unwrap();
        (dir, pool)
    }

    #[test]
    fn test_pool_creation() {
        let (_dir, pool) = fresh_pool();
        let snapshot = pool.stats().unwrap();

        // The constructor pre-opens the default two minimum connections.
        assert!(snapshot.total_connections_created >= 2);
        assert_eq!(snapshot.connections_in_pool, 2);
    }

    #[tokio::test]
    async fn test_get_connection() {
        let (_dir, pool) = fresh_pool();
        let guard = pool.get_connection().await.unwrap();

        // The checked-out connection must accept SQL.
        guard
            .connection()
            .execute("CREATE TABLE test (id INTEGER)", [])
            .unwrap();

        assert_eq!(pool.stats().unwrap().active_connections, 1);
    }

    #[tokio::test]
    async fn test_connection_return() {
        let (_dir, pool) = fresh_pool();

        {
            let _guard = pool.get_connection().await.unwrap();
            assert_eq!(pool.stats().unwrap().active_connections, 1);
        }

        // Dropping the guard hands the connection back to the pool.
        let after = pool.stats().unwrap();
        assert_eq!(after.active_connections, 0);
        assert!(after.connections_in_pool > 0);
    }
}