1use crate::drivers::Driver;
2use parking_lot::{Condvar, Mutex};
3#[cfg(feature = "watcher")]
4use sqlite_watcher::connection::State;
5#[cfg(feature = "watcher")]
6use sqlite_watcher::watcher::Watcher;
7use std::path::PathBuf;
8use std::sync::Arc;
9use std::time::Duration;
10
/// Settings consumed by [`ConnectionPool::new`] when building a pool.
pub struct ConnectionPoolConfig {
    /// Path of the database file that the write connection and every read
    /// connection are opened on.
    pub file_path: PathBuf,
    /// Number of read connections opened when the pool is created.
    pub max_read_connection_count: usize,
    /// Timeout applied while [`ConnectionPool::connection`] waits for an idle read
    /// connection; `None` waits indefinitely.
    pub connection_acquire_timeout: Option<Duration>,
    /// Watcher that the write connection syncs with and publishes changes to.
    #[cfg(feature = "watcher")]
    pub watcher: Arc<Watcher>,
}
18
19pub struct ConnectionPool<T: Driver, A: ConnectionAdapter<T>> {
20 read_connections: Mutex<Vec<A>>,
21 reader_condvar: Condvar,
22 write_connection: Mutex<WatchedConnection<T>>,
23 config: ConnectionPoolConfig,
24}
25
26#[derive(Debug, thiserror::Error)]
27pub enum ConnectionPoolError<E> {
28 #[error(transparent)]
29 Driver(#[from] E),
30 #[error("Failed to acquire connection in time")]
31 ConnectionAcquireTimeout,
32 #[error("Failed to setup connection watcher")]
33 WatcherSetup,
34}
35
36impl<T: Driver, A: ConnectionAdapter<T>> ConnectionPool<T, A> {
37 pub fn new(
45 config: ConnectionPoolConfig,
46 ) -> Result<Arc<Self>, ConnectionPoolError<T::ConnectionError>> {
47 let watched_connection = T::new_write_connection(&config.file_path)
48 .inspect_err(|e| tracing::error!("Failed to create write connection: {e:?}"))?;
49 #[cfg(feature = "watcher")]
50 let watched_connection = WatchedConnection::new(watched_connection).map_err(|e| {
51 tracing::error!("Failed to setup connection watcher: {e:?}");
52 ConnectionPoolError::WatcherSetup
53 })?;
54 #[cfg(not(feature = "watcher"))]
55 let watched_connection = WatchedConnection::new(watched_connection);
56
57 let mut read_connections = Vec::with_capacity(config.max_read_connection_count);
58 for _ in 0..config.max_read_connection_count {
59 read_connections.push(A::from_driver_connection(
60 T::new_read_connection(&config.file_path)
61 .inspect_err(|e| tracing::error!("Failed to create read connection: {e:?}"))?,
62 ));
63 }
64 Ok(Arc::new(Self {
65 write_connection: Mutex::new(watched_connection),
66 read_connections: Mutex::new(read_connections),
67 reader_condvar: Condvar::new(),
68 config,
69 }))
70 }
71
72 pub fn connection(
83 self: &Arc<Self>,
84 ) -> Result<PooledConnection<T, A>, ConnectionPoolError<T::Error>> {
85 let mut rd_connections = self.read_connections.lock();
86 loop {
87 if let Some(rd_connection) = rd_connections.pop() {
88 return Ok(PooledConnection::new(self.clone(), rd_connection));
89 } else if let Some(duration) = self.config.connection_acquire_timeout {
90 if self
91 .reader_condvar
92 .wait_for(&mut rd_connections, duration)
93 .timed_out()
94 {
95 return Err(ConnectionPoolError::ConnectionAcquireTimeout);
96 }
97 } else {
98 self.reader_condvar.wait(&mut rd_connections);
99 }
100 }
101 }
102
103 pub(crate) fn transaction<F, R, E>(&self, closure: F) -> Result<R, E>
104 where
105 F: FnOnce(&mut T::Transaction<'_>) -> Result<R, E>,
106 E: From<T::Error>,
107 {
108 let mut writer_connection = self.write_connection.lock();
109 #[cfg(feature = "watcher")]
110 {
111 writer_connection.transaction(closure, &self.config.watcher)
112 }
113 #[cfg(not(feature = "watcher"))]
114 {
115 writer_connection.transaction(closure)
116 }
117 }
118
119 fn return_to_pool(&self, conn: A) {
120 let mut read_connections = self.read_connections.lock();
121 read_connections.push(conn);
122 drop(read_connections);
123 self.reader_condvar.notify_one();
124 }
125
126 #[cfg(feature = "watcher")]
127 pub fn watcher(&self) -> &Arc<Watcher> {
128 &self.config.watcher
129 }
130}
131
132pub trait ConnectionAdapter<T: Driver> {
133 fn from_driver_connection(connection: T::Connection) -> Self;
134}
135
136pub struct PooledConnection<T: Driver, A: ConnectionAdapter<T>> {
137 pub(crate) pool: Arc<ConnectionPool<T, A>>,
138 conn: Option<A>,
139}
140
141impl<T: Driver, A: ConnectionAdapter<T>> Drop for PooledConnection<T, A> {
142 fn drop(&mut self) {
143 let conn = self.conn.take().expect("Connection should be set");
144 self.pool.return_to_pool(conn);
145 }
146}
147
148impl<T: Driver, A: ConnectionAdapter<T>> PooledConnection<T, A> {
149 fn new(pool: Arc<ConnectionPool<T, A>>, connection: A) -> PooledConnection<T, A> {
150 Self {
151 pool,
152 conn: Some(connection),
153 }
154 }
155
156 pub(crate) fn connection(&self) -> &A {
157 self.conn.as_ref().expect("Connection should be set")
158 }
159
160 pub(crate) fn connection_mut(&mut self) -> &mut A {
161 self.conn.as_mut().expect("Connection should be set")
162 }
163}
164
165struct WatchedConnection<T>
166where
167 T: Driver,
168{
169 connection: T::Connection,
170 #[cfg(feature = "watcher")]
171 state: State,
172}
173
#[cfg(feature = "watcher")]
impl<T: Driver> WatchedConnection<T> {
    /// Starts change tracking on `connection` and pairs it with fresh watcher state.
    ///
    /// # Errors
    ///
    /// Returns the driver error if the tracking-setup statement fails.
    fn new(mut connection: T::Connection) -> Result<Self, <T as Driver>::Error> {
        use sqlite_watcher::statement::Statement;

        State::start_tracking().execute_mut(&mut connection)?;
        let state = State::new();
        Ok(Self { connection, state })
    }

    /// Passes `closure` to the driver's `write` and brackets it with watcher
    /// bookkeeping.
    ///
    /// If the pre-write sync fails, the closure is never run and the error is
    /// returned. A failure while publishing changes afterwards is only logged; the
    /// write's own result is always what gets returned.
    fn transaction<F, R, E>(&mut self, closure: F, watcher: &Watcher) -> Result<R, E>
    where
        F: FnOnce(&mut T::Transaction<'_>) -> Result<R, E>,
        E: From<T::Error>,
    {
        self.before_write(watcher)?;
        let outcome = T::write(&mut self.connection, closure);
        match self.after_write(watcher) {
            Ok(()) => {}
            Err(e) => {
                tracing::error!("Failed to publish updates to watcher: {e:?}");
            }
        }
        outcome
    }

    /// Executes the statement produced by `State::sync_tables`, if any.
    fn before_write(&mut self, watcher: &Watcher) -> Result<(), <T as Driver>::Error> {
        use sqlite_watcher::statement::Statement;

        let Some(stmt) = self.state.sync_tables(watcher) else {
            return Ok(());
        };
        stmt.execute_mut(&mut self.connection)?;
        Ok(())
    }

    /// Executes the statement produced by `State::publish_changes`.
    fn after_write(&mut self, watcher: &Watcher) -> Result<(), <T as Driver>::Error> {
        use sqlite_watcher::statement::Statement;

        self.state
            .publish_changes(watcher)
            .execute_mut(&mut self.connection)?;
        Ok(())
    }
}
217
218#[cfg(not(feature = "watcher"))]
219impl<T> WatchedConnection<T>
220where
221 T: Driver,
222{
223 fn new(connection: T::Connection) -> Self {
224 Self { connection }
225 }
226
227 fn transaction<F, R, E>(&mut self, closure: F) -> Result<R, E>
228 where
229 F: FnOnce(&mut T::Transaction<'_>) -> Result<R, E>,
230 E: From<T::Error>,
231 {
232 T::write(&mut self.connection, closure)
233 }
234}