Skip to main content

things_mcp/core/reader/
pool.rs

//! Semaphore-throttled, short-lived read-only (RO) connection pool.
//!
//! Mirrors `zotero-connector`'s pattern: bound concurrent readers with a
//! `tokio::sync::Semaphore`, open a fresh `Connection` per `with_conn` call
//! using URI flags (`mode=ro`, `nolock=1`, `immutable=1`), and run the closure
//! inside `spawn_blocking`. Each call picks up Things' latest committed state
//! because the connection lifetime spans only one query. Note that
//! `immutable=1` tells SQLite the file will not change while the connection is
//! open, so locking and change detection are disabled for that single query.
8
9use std::path::{Path, PathBuf};
10use std::sync::Arc;
11
12use rusqlite::{Connection, OpenFlags};
13use tokio::sync::Semaphore;
14
15use crate::core::error::ThingsError;
16
17pub fn open_read_only(db: &Path) -> Result<Connection, ThingsError> {
18    let uri = format!("file:{}?mode=ro&nolock=1&immutable=1", db.to_string_lossy());
19    let conn = Connection::open_with_flags(
20        &uri,
21        OpenFlags::SQLITE_OPEN_READ_ONLY
22            | OpenFlags::SQLITE_OPEN_NO_MUTEX
23            | OpenFlags::SQLITE_OPEN_URI,
24    )?;
25    conn.busy_timeout(std::time::Duration::from_millis(500))?;
26    Ok(conn)
27}
28
29#[derive(Clone)]
30pub struct ReaderPool {
31    inner: Arc<Inner>,
32}
33
34impl std::fmt::Debug for ReaderPool {
35    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36        f.debug_struct("ReaderPool")
37            .field("path", &self.inner.path)
38            .finish_non_exhaustive()
39    }
40}
41
42struct Inner {
43    path: PathBuf,
44    sem: Semaphore,
45}
46
47impl ReaderPool {
48    pub async fn new(db_path: PathBuf, max: usize) -> Result<Self, ThingsError> {
49        // Validate the path + permissions up front.
50        let _probe = open_read_only(&db_path)?;
51        Ok(Self {
52            inner: Arc::new(Inner {
53                path: db_path,
54                sem: Semaphore::new(max),
55            }),
56        })
57    }
58
59    pub fn db_path(&self) -> &Path {
60        &self.inner.path
61    }
62
63    pub async fn with_conn<F, R>(&self, f: F) -> Result<R, ThingsError>
64    where
65        F: FnOnce(&Connection) -> rusqlite::Result<R> + Send + 'static,
66        R: Send + 'static,
67    {
68        let permit = self
69            .inner
70            .sem
71            .acquire()
72            .await
73            .map_err(|e| ThingsError::Sqlite(format!("semaphore closed: {e}")))?;
74        let path = self.inner.path.clone();
75        let result = tokio::task::spawn_blocking(move || -> Result<R, ThingsError> {
76            let conn = open_read_only(&path)?;
77            f(&conn).map_err(ThingsError::from)
78        })
79        .await
80        .map_err(|e| ThingsError::Sqlite(format!("join: {e}")))?;
81        drop(permit);
82        result
83    }
84}
85
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::reader::fixture::build_fixture;
    use tempfile::tempdir;

    #[tokio::test]
    async fn pool_opens_and_runs_a_query() {
        let tmp = tempdir().unwrap();
        let path = tmp.path().join("p.sqlite");
        build_fixture(&path).unwrap();
        let pool = ReaderPool::new(path, 2).await.unwrap();
        let n: i64 = pool
            .with_conn(|c| c.query_row("SELECT COUNT(*) FROM TMTask", [], |r| r.get(0)))
            .await
            .unwrap();
        assert_eq!(n, 15);
    }

    #[tokio::test]
    async fn pool_caps_concurrency() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::sync::Arc;
        use std::time::Duration;

        // Four queries contend for two permits. Record the peak number of
        // closures running at once and check that it never exceeds the cap.
        const PERMITS: usize = 2;
        let tmp = tempdir().unwrap();
        let path = tmp.path().join("p.sqlite");
        build_fixture(&path).unwrap();
        let pool = ReaderPool::new(path, PERMITS).await.unwrap();

        let in_flight = Arc::new(AtomicUsize::new(0));
        let peak = Arc::new(AtomicUsize::new(0));

        let mut handles = Vec::new();
        for table in ["TMTask", "TMTag", "TMArea", "TMTask"] {
            let pool = pool.clone();
            let in_flight = Arc::clone(&in_flight);
            let peak = Arc::clone(&peak);
            handles.push(tokio::spawn(async move {
                pool.with_conn(move |c| {
                    let now = in_flight.fetch_add(1, Ordering::SeqCst) + 1;
                    peak.fetch_max(now, Ordering::SeqCst);
                    // Hold the permit long enough for the other tasks to contend.
                    std::thread::sleep(Duration::from_millis(50));
                    let count = c.query_row(
                        &format!("SELECT COUNT(*) FROM {table}"),
                        [],
                        |r| r.get::<_, i64>(0),
                    );
                    in_flight.fetch_sub(1, Ordering::SeqCst);
                    count
                })
                .await
            }));
        }
        for h in handles {
            h.await.unwrap().unwrap();
        }

        let peak = peak.load(Ordering::SeqCst);
        assert!(
            peak <= PERMITS,
            "observed {peak} concurrent readers with only {PERMITS} permits"
        );
    }
}