tempo_cli/db/
pool.rs

1use anyhow::Result;
2use rusqlite::{Connection, OpenFlags};
3use std::collections::VecDeque;
4use std::path::{Path, PathBuf};
5use std::sync::{Arc, Mutex};
6use std::time::{Duration, Instant};
7
8/// A database connection with metadata
9#[derive(Debug)]
10pub struct PooledConnection {
11    pub connection: Connection,
12    created_at: Instant,
13    last_used: Instant,
14    use_count: usize,
15}
16
17impl PooledConnection {
18    fn new(connection: Connection) -> Self {
19        let now = Instant::now();
20        Self {
21            connection,
22            created_at: now,
23            last_used: now,
24            use_count: 0,
25        }
26    }
27
28    fn mark_used(&mut self) {
29        self.last_used = Instant::now();
30        self.use_count += 1;
31    }
32
33    fn is_expired(&self, max_lifetime: Duration) -> bool {
34        self.created_at.elapsed() > max_lifetime
35    }
36
37    fn is_idle_too_long(&self, max_idle: Duration) -> bool {
38        self.last_used.elapsed() > max_idle
39    }
40}
41
/// Tunable limits for a SQLite connection pool.
#[derive(Debug, Clone)]
pub struct PoolConfig {
    /// Upper bound on connections (checked out plus idle) the pool will open.
    pub max_connections: usize,
    /// Number of connections opened eagerly when the pool is created.
    pub min_connections: usize,
    /// Connections older than this are discarded when the pool cleans up.
    pub max_lifetime: Duration,
    /// Idle connections unused for longer than this are discarded on cleanup.
    pub max_idle_time: Duration,
    /// How long a checkout keeps retrying before giving up with an error.
    pub connection_timeout: Duration,
}

impl Default for PoolConfig {
    /// 10 max / 2 min connections, 1 h lifetime, 10 min idle limit, 30 s timeout.
    fn default() -> Self {
        const MINUTE: u64 = 60;
        PoolConfig {
            min_connections: 2,
            max_connections: 10,
            connection_timeout: Duration::from_secs(30),
            max_idle_time: Duration::from_secs(10 * MINUTE),
            max_lifetime: Duration::from_secs(60 * MINUTE),
        }
    }
}
63
64/// A connection pool for SQLite databases
65pub struct DatabasePool {
66    db_path: PathBuf,
67    pool: Arc<Mutex<VecDeque<PooledConnection>>>,
68    config: PoolConfig,
69    stats: Arc<Mutex<PoolStats>>,
70}
71
/// Counters describing pool activity.
///
/// The pool keeps one instance internally and hands out copies as snapshots;
/// deriving `Clone`/`Copy`/`PartialEq` lets callers keep and compare them.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct PoolStats {
    /// Connections opened over the pool's lifetime.
    pub total_connections_created: usize,
    /// Connections currently checked out by callers.
    pub active_connections: usize,
    /// Idle connections tracked as sitting in the pool's queue.
    pub connections_in_pool: usize,
    /// Number of checkout requests made against the pool.
    pub connection_requests: usize,
    /// Checkout requests that gave up after the configured timeout.
    pub connection_timeouts: usize,
}
80
81impl DatabasePool {
82    /// Create a new database pool
83    pub fn new<P: AsRef<Path>>(db_path: P, config: PoolConfig) -> Result<Self> {
84        let db_path = db_path.as_ref().to_path_buf();
85
86        // Create parent directory if it doesn't exist
87        if let Some(parent) = db_path.parent() {
88            std::fs::create_dir_all(parent)?;
89        }
90
91        let pool = Self {
92            db_path,
93            pool: Arc::new(Mutex::new(VecDeque::new())),
94            config,
95            stats: Arc::new(Mutex::new(PoolStats::default())),
96        };
97
98        // Pre-populate with minimum connections
99        pool.ensure_min_connections()?;
100
101        Ok(pool)
102    }
103
104    /// Create a new database pool with default configuration
105    pub fn new_with_defaults<P: AsRef<Path>>(db_path: P) -> Result<Self> {
106        Self::new(db_path, PoolConfig::default())
107    }
108
109    /// Get a connection from the pool
110    pub async fn get_connection(&self) -> Result<PooledConnectionGuard> {
111        let start = Instant::now();
112
113        // Update stats
114        {
115            let mut stats = self
116                .stats
117                .lock()
118                .map_err(|e| anyhow::anyhow!("Failed to acquire stats lock: {}", e))?;
119            stats.connection_requests += 1;
120        }
121
122        loop {
123            // Try to get a connection from the pool
124            if let Some(mut conn) = self.try_get_from_pool()? {
125                conn.mark_used();
126
127                // Update stats
128                {
129                    let mut stats = self
130                        .stats
131                        .lock()
132                        .map_err(|e| anyhow::anyhow!("Failed to acquire stats lock: {}", e))?;
133                    stats.active_connections += 1;
134                    stats.connections_in_pool = stats.connections_in_pool.saturating_sub(1);
135                }
136
137                return Ok(PooledConnectionGuard::new(
138                    conn,
139                    self.pool.clone(),
140                    self.stats.clone(),
141                ));
142            }
143
144            // If no connection available, try to create a new one
145            if self.can_create_new_connection()? {
146                let conn = self.create_connection()?;
147
148                // Update stats
149                {
150                    let mut stats = self
151                        .stats
152                        .lock()
153                        .map_err(|e| anyhow::anyhow!("Failed to acquire stats lock: {}", e))?;
154                    stats.total_connections_created += 1;
155                    stats.active_connections += 1;
156                }
157
158                return Ok(PooledConnectionGuard::new(
159                    conn,
160                    self.pool.clone(),
161                    self.stats.clone(),
162                ));
163            }
164
165            // Check for timeout
166            if start.elapsed() > self.config.connection_timeout {
167                let mut stats = self
168                    .stats
169                    .lock()
170                    .map_err(|e| anyhow::anyhow!("Failed to acquire stats lock: {}", e))?;
171                stats.connection_timeouts += 1;
172                return Err(anyhow::anyhow!(
173                    "Connection timeout after {:?}",
174                    self.config.connection_timeout
175                ));
176            }
177
178            // Wait a bit before retrying
179            tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
180        }
181    }
182
183    /// Try to get a connection from the existing pool
184    fn try_get_from_pool(&self) -> Result<Option<PooledConnection>> {
185        let mut pool = self
186            .pool
187            .lock()
188            .map_err(|e| anyhow::anyhow!("Failed to acquire pool lock: {}", e))?;
189
190        // Clean up expired/idle connections first
191        self.cleanup_connections(&mut pool)?;
192
193        // Try to get a connection
194        Ok(pool.pop_front())
195    }
196
197    /// Check if we can create a new connection
198    fn can_create_new_connection(&self) -> Result<bool> {
199        let stats = self
200            .stats
201            .lock()
202            .map_err(|e| anyhow::anyhow!("Failed to acquire stats lock: {}", e))?;
203        Ok(stats.active_connections + stats.connections_in_pool < self.config.max_connections)
204    }
205
206    /// Create a new database connection
207    fn create_connection(&self) -> Result<PooledConnection> {
208        let connection = Connection::open_with_flags(
209            &self.db_path,
210            OpenFlags::SQLITE_OPEN_READ_WRITE
211                | OpenFlags::SQLITE_OPEN_CREATE
212                | OpenFlags::SQLITE_OPEN_NO_MUTEX,
213        )?;
214
215        // Configure the connection
216        connection.pragma_update(None, "foreign_keys", "ON")?;
217        connection.pragma_update(None, "journal_mode", "WAL")?;
218        connection.pragma_update(None, "synchronous", "NORMAL")?;
219        connection.pragma_update(None, "cache_size", "-64000")?;
220
221        // Run migrations
222        crate::db::migrations::run_migrations(&connection)?;
223
224        Ok(PooledConnection::new(connection))
225    }
226
227    /// Clean up expired and idle connections
228    fn cleanup_connections(&self, pool: &mut VecDeque<PooledConnection>) -> Result<()> {
229        let mut to_remove = Vec::new();
230
231        for (index, conn) in pool.iter().enumerate() {
232            if conn.is_expired(self.config.max_lifetime)
233                || conn.is_idle_too_long(self.config.max_idle_time)
234            {
235                to_remove.push(index);
236            }
237        }
238
239        // Remove connections in reverse order to maintain indices
240        for index in to_remove.iter().rev() {
241            pool.remove(*index);
242        }
243
244        Ok(())
245    }
246
247    /// Ensure minimum number of connections are available
248    fn ensure_min_connections(&self) -> Result<()> {
249        let mut pool = self
250            .pool
251            .lock()
252            .map_err(|e| anyhow::anyhow!("Failed to acquire pool lock: {}", e))?;
253
254        while pool.len() < self.config.min_connections {
255            let conn = self.create_connection()?;
256            pool.push_back(conn);
257
258            // Update stats
259            let mut stats = self
260                .stats
261                .lock()
262                .map_err(|e| anyhow::anyhow!("Failed to acquire stats lock: {}", e))?;
263            stats.total_connections_created += 1;
264            stats.connections_in_pool += 1;
265        }
266
267        Ok(())
268    }
269
270    /// Return a connection to the pool
271    fn return_connection(&self, conn: PooledConnection) -> Result<()> {
272        let mut pool = self
273            .pool
274            .lock()
275            .map_err(|e| anyhow::anyhow!("Failed to acquire pool lock: {}", e))?;
276
277        // Check if we should keep this connection
278        if !conn.is_expired(self.config.max_lifetime) && pool.len() < self.config.max_connections {
279            pool.push_back(conn);
280
281            // Update stats
282            let mut stats = self
283                .stats
284                .lock()
285                .map_err(|e| anyhow::anyhow!("Failed to acquire stats lock: {}", e))?;
286            stats.connections_in_pool += 1;
287            stats.active_connections = stats.active_connections.saturating_sub(1);
288        } else {
289            // Update stats - connection is being dropped
290            let mut stats = self
291                .stats
292                .lock()
293                .map_err(|e| anyhow::anyhow!("Failed to acquire stats lock: {}", e))?;
294            stats.active_connections = stats.active_connections.saturating_sub(1);
295        }
296
297        Ok(())
298    }
299
300    /// Get current pool statistics
301    pub fn stats(&self) -> Result<PoolStats> {
302        let stats = self
303            .stats
304            .lock()
305            .map_err(|e| anyhow::anyhow!("Failed to acquire stats lock: {}", e))?;
306        Ok(PoolStats {
307            total_connections_created: stats.total_connections_created,
308            active_connections: stats.active_connections,
309            connections_in_pool: stats.connections_in_pool,
310            connection_requests: stats.connection_requests,
311            connection_timeouts: stats.connection_timeouts,
312        })
313    }
314
315    /// Close all connections in the pool
316    pub fn close(&self) -> Result<()> {
317        let mut pool = self
318            .pool
319            .lock()
320            .map_err(|e| anyhow::anyhow!("Failed to acquire pool lock: {}", e))?;
321        pool.clear();
322
323        let mut stats = self
324            .stats
325            .lock()
326            .map_err(|e| anyhow::anyhow!("Failed to acquire stats lock: {}", e))?;
327        stats.connections_in_pool = 0;
328
329        Ok(())
330    }
331}
332
333/// A guard that automatically returns connections to the pool when dropped
334pub struct PooledConnectionGuard {
335    connection: Option<PooledConnection>,
336    pool: Arc<Mutex<VecDeque<PooledConnection>>>,
337    stats: Arc<Mutex<PoolStats>>,
338}
339
340impl PooledConnectionGuard {
341    fn new(
342        connection: PooledConnection,
343        pool: Arc<Mutex<VecDeque<PooledConnection>>>,
344        stats: Arc<Mutex<PoolStats>>,
345    ) -> Self {
346        Self {
347            connection: Some(connection),
348            pool,
349            stats,
350        }
351    }
352
353    /// Get a reference to the underlying connection
354    pub fn connection(&self) -> &Connection {
355        &self.connection.as_ref().unwrap().connection
356    }
357}
358
359impl Drop for PooledConnectionGuard {
360    fn drop(&mut self) {
361        if let Some(conn) = self.connection.take() {
362            // Try to return connection to pool
363            let mut pool = match self.pool.lock() {
364                Ok(pool) => pool,
365                Err(_) => {
366                    // Pool lock is poisoned, just update stats
367                    if let Ok(mut stats) = self.stats.lock() {
368                        stats.active_connections = stats.active_connections.saturating_sub(1);
369                    }
370                    return;
371                }
372            };
373
374            // Check if we should keep this connection
375            if !conn.is_expired(Duration::from_secs(3600)) && pool.len() < 10 {
376                pool.push_back(conn);
377                if let Ok(mut stats) = self.stats.lock() {
378                    stats.connections_in_pool += 1;
379                    stats.active_connections = stats.active_connections.saturating_sub(1);
380                }
381            } else {
382                // Connection is being dropped
383                if let Ok(mut stats) = self.stats.lock() {
384                    stats.active_connections = stats.active_connections.saturating_sub(1);
385                }
386            }
387        }
388    }
389}
390
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Build a default-configured pool inside a fresh temporary directory.
    /// The `TempDir` is returned so it stays alive as long as the pool.
    fn fresh_pool() -> (tempfile::TempDir, DatabasePool) {
        let dir = tempdir().unwrap();
        let pool = DatabasePool::new_with_defaults(dir.path().join("test.db")).unwrap();
        (dir, pool)
    }

    #[test]
    fn test_pool_creation() {
        let (_dir, pool) = fresh_pool();
        let snapshot = pool.stats().unwrap();

        // `new` eagerly opens `min_connections` (2 by default).
        assert!(snapshot.total_connections_created >= 2);
        assert_eq!(snapshot.connections_in_pool, 2);
    }

    #[tokio::test]
    async fn test_get_connection() {
        let (_dir, pool) = fresh_pool();
        let guard = pool.get_connection().await.unwrap();

        // The checked-out connection must be usable.
        guard
            .connection()
            .execute("CREATE TABLE test (id INTEGER)", [])
            .unwrap();

        assert_eq!(pool.stats().unwrap().active_connections, 1);
    }

    #[tokio::test]
    async fn test_connection_return() {
        let (_dir, pool) = fresh_pool();

        {
            let _guard = pool.get_connection().await.unwrap();
            assert_eq!(pool.stats().unwrap().active_connections, 1);
        }

        // Dropping the guard hands the connection back to the idle queue.
        let after = pool.stats().unwrap();
        assert_eq!(after.active_connections, 0);
        assert!(after.connections_in_pool > 0);
    }
}