things_mcp/core/reader/
pool.rs1use std::path::{Path, PathBuf};
10use std::sync::Arc;
11
12use rusqlite::{Connection, OpenFlags};
13use tokio::sync::Semaphore;
14
15use crate::core::error::ThingsError;
16
17pub fn open_read_only(db: &Path) -> Result<Connection, ThingsError> {
18 let uri = format!("file:{}?mode=ro&nolock=1&immutable=1", db.to_string_lossy());
19 let conn = Connection::open_with_flags(
20 &uri,
21 OpenFlags::SQLITE_OPEN_READ_ONLY
22 | OpenFlags::SQLITE_OPEN_NO_MUTEX
23 | OpenFlags::SQLITE_OPEN_URI,
24 )?;
25 conn.busy_timeout(std::time::Duration::from_millis(500))?;
26 Ok(conn)
27}
28
29#[derive(Clone)]
30pub struct ReaderPool {
31 inner: Arc<Inner>,
32}
33
34impl std::fmt::Debug for ReaderPool {
35 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36 f.debug_struct("ReaderPool")
37 .field("path", &self.inner.path)
38 .finish_non_exhaustive()
39 }
40}
41
42struct Inner {
43 path: PathBuf,
44 sem: Semaphore,
45}
46
47impl ReaderPool {
48 pub async fn new(db_path: PathBuf, max: usize) -> Result<Self, ThingsError> {
49 let _probe = open_read_only(&db_path)?;
51 Ok(Self {
52 inner: Arc::new(Inner {
53 path: db_path,
54 sem: Semaphore::new(max),
55 }),
56 })
57 }
58
59 pub fn db_path(&self) -> &Path {
60 &self.inner.path
61 }
62
63 pub async fn with_conn<F, R>(&self, f: F) -> Result<R, ThingsError>
64 where
65 F: FnOnce(&Connection) -> rusqlite::Result<R> + Send + 'static,
66 R: Send + 'static,
67 {
68 let permit = self
69 .inner
70 .sem
71 .acquire()
72 .await
73 .map_err(|e| ThingsError::Sqlite(format!("semaphore closed: {e}")))?;
74 let path = self.inner.path.clone();
75 let result = tokio::task::spawn_blocking(move || -> Result<R, ThingsError> {
76 let conn = open_read_only(&path)?;
77 f(&conn).map_err(ThingsError::from)
78 })
79 .await
80 .map_err(|e| ThingsError::Sqlite(format!("join: {e}")))?;
81 drop(permit);
82 result
83 }
84}
85
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::reader::fixture::build_fixture;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::time::Duration;
    use tempfile::tempdir;

    #[tokio::test]
    async fn pool_opens_and_runs_a_query() {
        let tmp = tempdir().unwrap();
        let path = tmp.path().join("p.sqlite");
        build_fixture(&path).unwrap();
        let pool = ReaderPool::new(path, 2).await.unwrap();
        let n: i64 = pool
            .with_conn(|c| c.query_row("SELECT COUNT(*) FROM TMTask", [], |r| r.get(0)))
            .await
            .unwrap();
        assert_eq!(n, 15);
    }

    /// Launches more queries than the pool allows and records how many run
    /// simultaneously inside `with_conn`; the peak must never exceed the cap.
    #[tokio::test]
    async fn pool_caps_concurrency() {
        const CAP: usize = 2;
        let tmp = tempdir().unwrap();
        let path = tmp.path().join("p.sqlite");
        build_fixture(&path).unwrap();
        let pool = ReaderPool::new(path, CAP).await.unwrap();

        let active = Arc::new(AtomicUsize::new(0));
        let peak = Arc::new(AtomicUsize::new(0));

        let handles: Vec<_> = ["TMTask", "TMTag", "TMArea", "TMTask"]
            .into_iter()
            .map(|table| {
                let p = pool.clone();
                let active = Arc::clone(&active);
                let peak = Arc::clone(&peak);
                tokio::spawn(async move {
                    p.with_conn(move |c| {
                        let now = active.fetch_add(1, Ordering::SeqCst) + 1;
                        peak.fetch_max(now, Ordering::SeqCst);
                        // Hold the slot long enough that the other tasks would
                        // overlap with this one if the cap were not enforced.
                        std::thread::sleep(Duration::from_millis(50));
                        let n = c.query_row(
                            &format!("SELECT COUNT(*) FROM {table}"),
                            [],
                            |r| r.get::<_, i64>(0),
                        );
                        active.fetch_sub(1, Ordering::SeqCst);
                        n
                    })
                    .await
                })
            })
            .collect();

        for h in handles {
            h.await.unwrap().unwrap();
        }

        let peak = peak.load(Ordering::SeqCst);
        assert!(
            (1..=CAP).contains(&peak),
            "peak concurrency {peak} outside 1..={CAP}"
        );
    }
}