1use std::path::Path;
29use std::sync::Arc;
30
31use deadpool_sqlite::{Config, Hook, HookError, Pool, Runtime};
32use solo_core::{Error, Result, VectorIndex};
33
34use crate::key_material::KeyMaterial;
35
/// Connection cap used by [`ReaderPool::new`] (passed to deadpool as `max_size`).
pub const DEFAULT_POOL_SIZE: usize = 2;
38
/// A pool of SQLite connections for read queries, bundled with the vector index.
///
/// Connections are configured by the `post_create` hook installed in
/// [`ReaderPool::with_size`].
///
/// NOTE(review): nothing in this module opens the connections read-only or sets
/// `PRAGMA query_only`, so "reader" is a convention rather than an enforced guarantee.
#[derive(Clone)]
pub struct ReaderPool {
    /// deadpool-sqlite pool that [`ReaderPool::interact`] checks connections out of.
    pool: Pool,
    /// Vector index supplied at construction; returned by [`ReaderPool::hnsw`].
    hnsw: Arc<dyn VectorIndex + Send + Sync>,
}
48
49impl ReaderPool {
50 pub fn new(
51 db_path: &Path,
52 key: Option<KeyMaterial>,
53 hnsw: Arc<dyn VectorIndex + Send + Sync>,
54 ) -> Result<Self> {
55 Self::with_size(db_path, key, DEFAULT_POOL_SIZE, hnsw)
56 }
57
58 pub fn with_size(
59 db_path: &Path,
60 key: Option<KeyMaterial>,
61 size: usize,
62 hnsw: Arc<dyn VectorIndex + Send + Sync>,
63 ) -> Result<Self> {
64 let cfg = Config::new(db_path);
65 let mut builder = cfg
66 .builder(Runtime::Tokio1)
67 .map_err(|e| Error::storage(format!("deadpool config: {e:?}")))?
68 .max_size(size);
69
70 if let Some(key) = key {
71 let key_hex = key.as_hex();
72 builder = builder.post_create(Hook::async_fn(move |conn, _metrics| {
73 let key_hex = key_hex.clone();
74 Box::pin(async move {
75 let pragma = format!("PRAGMA key = \"x'{}'\"", &*key_hex);
76 conn.interact(move |c| {
77 c.execute_batch(&pragma)?;
78 c.execute_batch(
79 "PRAGMA foreign_keys = ON;
80 PRAGMA busy_timeout = 5000;",
81 )?;
82 Ok::<_, rusqlite::Error>(())
83 })
84 .await
85 .map_err(|e| HookError::message(format!("interact: {e}")))?
86 .map_err(|e| HookError::message(format!("PRAGMA key: {e}")))?;
87 Ok(())
88 })
89 }));
90 } else {
91 builder = builder.post_create(Hook::async_fn(|conn, _metrics| {
92 Box::pin(async move {
93 conn.interact(|c| {
94 c.execute_batch(
95 "PRAGMA foreign_keys = ON;
96 PRAGMA busy_timeout = 5000;",
97 )
98 })
99 .await
100 .map_err(|e| HookError::message(format!("interact: {e}")))?
101 .map_err(|e| HookError::message(format!("PRAGMA setup: {e}")))?;
102 Ok(())
103 })
104 }));
105 }
106
107 let pool = builder
108 .build()
109 .map_err(|e| Error::storage(format!("deadpool build: {e:?}")))?;
110 Ok(Self { pool, hnsw })
111 }
112
113 pub async fn interact<F, R>(&self, f: F) -> Result<R>
114 where
115 F: FnOnce(&mut rusqlite::Connection) -> rusqlite::Result<R> + Send + 'static,
116 R: Send + 'static,
117 {
118 let conn = self
119 .pool
120 .get()
121 .await
122 .map_err(|e| Error::storage(format!("pool get: {e:?}")))?;
123 conn.interact(f)
124 .await
125 .map_err(|e| Error::storage(format!("interact: {e}")))?
126 .map_err(|e| Error::storage(format!("rusqlite: {e}")))
127 }
128
129 pub fn hnsw(&self) -> &Arc<dyn VectorIndex + Send + Sync> {
130 &self.hnsw
131 }
132}
133
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_support::{StubVectorIndex, fixture_embedding, fixture_episode, open_test_db_at};
    use crate::writer::WriterActor;

    /// Highest migration version the test database is expected to report.
    ///
    /// NOTE(review): presumably tracks the crate's latest migration; bump it when a
    /// migration is added.
    const EXPECTED_SCHEMA_VERSION: u32 = 2;

    /// Builds a two-worker multi-threaded Tokio runtime with all drivers enabled.
    fn rt() -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_multi_thread()
            .worker_threads(2)
            .enable_all()
            .build()
            .unwrap()
    }

    /// Returns `MAX(version)` from `schema_migrations`.
    fn max_schema_version(conn: &rusqlite::Connection) -> rusqlite::Result<u32> {
        conn.query_row(
            "SELECT MAX(version) FROM schema_migrations",
            [],
            |row| row.get(0),
        )
    }

    #[test]
    fn pool_returns_connections() {
        let runtime = rt();
        let tmp = tempfile::TempDir::new().unwrap();
        let path = tmp.path().join("test.db");
        // Set up the on-disk database; the returned connection itself is unused.
        let _ = open_test_db_at(&path);

        runtime.block_on(async {
            let hnsw: Arc<dyn VectorIndex + Send + Sync> = Arc::new(StubVectorIndex::new(4));
            let pool = ReaderPool::new(&path, None, hnsw).unwrap();
            let version = pool
                .interact(|conn| max_schema_version(conn))
                .await
                .unwrap();
            assert_eq!(version, EXPECTED_SCHEMA_VERSION);
        });
    }

    #[test]
    fn reader_sees_writes_committed_through_writer_actor() {
        let runtime = rt();
        let tmp = tempfile::TempDir::new().unwrap();
        let path = tmp.path().join("test.db");
        let writer_conn = open_test_db_at(&path);
        let hnsw: Arc<dyn VectorIndex + Send + Sync> = Arc::new(StubVectorIndex::new(4));

        // The writer actor and the reader pool are given the same index instance.
        let crate::writer::WriterSpawn { handle, join: _ } =
            WriterActor::spawn(writer_conn, hnsw.clone());

        runtime.block_on(async {
            let pool = ReaderPool::new(&path, None, hnsw).unwrap();

            let episode = fixture_episode("reader-visibility test");
            let mid = handle
                .remember(episode.clone(), fixture_embedding(4))
                .await
                .unwrap();
            assert_eq!(mid, episode.memory_id);

            // After `remember` resolves, a separate pooled connection must see the row.
            let mid_str = mid.to_string();
            let content: String = pool
                .interact(move |conn| {
                    conn.query_row(
                        "SELECT content FROM episodes WHERE memory_id = ?",
                        [mid_str],
                        |row| row.get(0),
                    )
                })
                .await
                .unwrap();
            assert_eq!(content, "reader-visibility test");
        });

        drop(handle);
    }

    #[test]
    fn many_concurrent_reads_serve_from_pool() {
        let runtime = rt();
        let tmp = tempfile::TempDir::new().unwrap();
        let path = tmp.path().join("test.db");
        let _ = open_test_db_at(&path);

        runtime.block_on(async {
            let hnsw: Arc<dyn VectorIndex + Send + Sync> = Arc::new(StubVectorIndex::new(4));
            // 32 concurrent readers share 4 connections, so checkouts must queue.
            let pool = ReaderPool::with_size(&path, None, 4, hnsw).unwrap();

            // Go through the public `ReaderPool` API (Clone + interact), not the raw pool.
            let tasks: Vec<_> = (0..32)
                .map(|_| {
                    let pool = pool.clone();
                    tokio::spawn(async move {
                        pool.interact(|conn| max_schema_version(conn))
                            .await
                            .unwrap()
                    })
                })
                .collect();

            let mut versions = Vec::with_capacity(tasks.len());
            for task in tasks {
                versions.push(task.await.unwrap());
            }
            assert_eq!(versions, vec![EXPECTED_SCHEMA_VERSION; 32]);
        });
    }
}