sqlx_build_trust_sqlite/connection/
mod.rs1use futures_core::future::BoxFuture;
2use futures_intrusive::sync::MutexGuard;
3use futures_util::future;
4use libsqlite3_sys::{sqlite3, sqlite3_progress_handler};
5use sqlx_core::common::StatementCache;
6use sqlx_core::error::Error;
7use sqlx_core::transaction::Transaction;
8use std::cmp::Ordering;
9use std::fmt::{self, Debug, Formatter};
10use std::os::raw::{c_int, c_void};
11use std::panic::catch_unwind;
12use std::ptr::NonNull;
13
14use crate::connection::establish::EstablishParams;
15use crate::connection::worker::ConnectionWorker;
16use crate::options::OptimizeOnClose;
17use crate::statement::VirtualStatement;
18use crate::{Sqlite, SqliteConnectOptions};
19use sqlx_core::executor::Executor;
20use std::fmt::Write;
21
22pub(crate) use sqlx_core::connection::*;
23
24pub(crate) use handle::{ConnectionHandle, ConnectionHandleRaw};
25
26pub(crate) mod collation;
27pub(crate) mod describe;
28pub(crate) mod establish;
29pub(crate) mod execute;
30mod executor;
31mod explain;
32mod handle;
33mod intmap;
34
35mod worker;
36
37pub struct SqliteConnection {
48 optimize_on_close: OptimizeOnClose,
49 pub(crate) worker: ConnectionWorker,
50 pub(crate) row_channel_size: usize,
51}
52
53pub struct LockedSqliteHandle<'a> {
54 pub(crate) guard: MutexGuard<'a, ConnectionState>,
55}
56
/// Owning raw pointer to the heap-allocated progress-handler closure.
///
/// The pointer is produced with `Box::into_raw` in `LockedSqliteHandle::set_progress_handler`
/// and converted back into a `Box` (and dropped) by `ConnectionState::remove_progress_handler`.
pub(crate) struct Handler(NonNull<dyn FnMut() -> bool + Send + 'static>);

// SAFETY: `NonNull` is never `Send` on its own, but the trait-object type requires the pointee
// to be `Send`, and `Handler` is the sole owner of that allocation.
unsafe impl Send for Handler {}
60
61pub(crate) struct ConnectionState {
62 pub(crate) handle: ConnectionHandle,
63
64 pub(crate) transaction_depth: usize,
66
67 pub(crate) statements: Statements,
68
69 log_settings: LogSettings,
70
71 progress_handler_callback: Option<Handler>,
74}
75
76impl ConnectionState {
77 pub(crate) fn remove_progress_handler(&mut self) {
79 if let Some(mut handler) = self.progress_handler_callback.take() {
80 unsafe {
81 sqlite3_progress_handler(self.handle.as_ptr(), 0, None, 0 as *mut _);
82 let _ = { Box::from_raw(handler.0.as_mut()) };
83 }
84 }
85 }
86}
87
88pub(crate) struct Statements {
89 cached: StatementCache<VirtualStatement>,
91 temp: Option<VirtualStatement>,
93}
94
95impl SqliteConnection {
96 pub(crate) async fn establish(options: &SqliteConnectOptions) -> Result<Self, Error> {
97 let params = EstablishParams::from_options(options)?;
98 let worker = ConnectionWorker::establish(params).await?;
99 Ok(Self {
100 optimize_on_close: options.optimize_on_close.clone(),
101 worker,
102 row_channel_size: options.row_channel_size,
103 })
104 }
105
106 pub async fn lock_handle(&mut self) -> Result<LockedSqliteHandle<'_>, Error> {
111 let guard = self.worker.unlock_db().await?;
112
113 Ok(LockedSqliteHandle { guard })
114 }
115}
116
117impl Debug for SqliteConnection {
118 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
119 f.debug_struct("SqliteConnection")
120 .field("row_channel_size", &self.row_channel_size)
121 .field("cached_statements_size", &self.cached_statements_size())
122 .finish()
123 }
124}
125
126impl Connection for SqliteConnection {
127 type Database = Sqlite;
128
129 type Options = SqliteConnectOptions;
130
131 fn close(mut self) -> BoxFuture<'static, Result<(), Error>> {
132 Box::pin(async move {
133 if let OptimizeOnClose::Enabled { analysis_limit } = self.optimize_on_close {
134 let mut pragma_string = String::new();
135 if let Some(limit) = analysis_limit {
136 write!(pragma_string, "PRAGMA analysis_limit = {limit}; ").ok();
137 }
138 pragma_string.push_str("PRAGMA optimize;");
139 self.execute(&*pragma_string).await?;
140 }
141 let shutdown = self.worker.shutdown();
142 drop(self);
145 shutdown.await
147 })
148 }
149
150 fn close_hard(self) -> BoxFuture<'static, Result<(), Error>> {
151 Box::pin(async move {
152 drop(self);
153 Ok(())
154 })
155 }
156
157 fn ping(&mut self) -> BoxFuture<'_, Result<(), Error>> {
159 Box::pin(self.worker.ping())
160 }
161
162 fn begin(&mut self) -> BoxFuture<'_, Result<Transaction<'_, Self::Database>, Error>>
163 where
164 Self: Sized,
165 {
166 Transaction::begin(self)
167 }
168
169 fn cached_statements_size(&self) -> usize {
170 self.worker
171 .shared
172 .cached_statements_size
173 .load(std::sync::atomic::Ordering::Acquire)
174 }
175
176 fn clear_cached_statements(&mut self) -> BoxFuture<'_, Result<(), Error>> {
177 Box::pin(async move {
178 self.worker.clear_cache().await?;
179 Ok(())
180 })
181 }
182
183 #[inline]
184 fn shrink_buffers(&mut self) {
185 }
187
188 #[doc(hidden)]
189 fn flush(&mut self) -> BoxFuture<'_, Result<(), Error>> {
190 Box::pin(future::ok(()))
195 }
196
197 #[doc(hidden)]
198 fn should_flush(&self) -> bool {
199 false
200 }
201}
202
/// C-ABI trampoline that SQLite invokes as the progress handler.
///
/// `callback` must point to a live value of type `F`. The closure is called once per
/// invocation and its result is translated into SQLite's convention:
///
/// * closure returns `true`  -> `0` (keep going),
/// * closure returns `false` -> `1` (SQLite treats non-zero as "interrupt"),
/// * closure panics          -> `1`; the panic is caught so it never unwinds into C.
extern "C" fn progress_callback<F>(callback: *mut c_void) -> c_int
where
    F: FnMut() -> bool,
{
    // The cast to `*mut F` happens inside the closure on purpose: that way the closure only
    // captures a `*mut c_void`, which keeps it `UnwindSafe` for an arbitrary `F`.
    let outcome = catch_unwind(|| {
        let handler = callback.cast::<F>();
        // SAFETY: the registrant passes a pointer to a live `F` as the user-data argument
        // and keeps it alive for as long as this trampoline is registered.
        unsafe { (*handler)() }
    });

    match outcome {
        Ok(true) => 0,
        Ok(false) | Err(_) => 1,
    }
}
217
218impl LockedSqliteHandle<'_> {
219 pub fn as_raw_handle(&mut self) -> NonNull<sqlite3> {
234 self.guard.handle.as_non_null_ptr()
235 }
236
237 pub fn create_collation(
241 &mut self,
242 name: &str,
243 compare: impl Fn(&str, &str) -> Ordering + Send + Sync + 'static,
244 ) -> Result<(), Error> {
245 collation::create_collation(&mut self.guard.handle, name, compare)
246 }
247
248 pub fn set_progress_handler<F>(&mut self, num_ops: i32, callback: F)
262 where
263 F: FnMut() -> bool + Send + 'static,
264 {
265 unsafe {
266 let callback_boxed = Box::new(callback);
267 let callback = NonNull::new_unchecked(Box::into_raw(callback_boxed));
269 let handler = callback.as_ptr() as *mut _;
270 self.guard.remove_progress_handler();
271 self.guard.progress_handler_callback = Some(Handler(callback));
272
273 sqlite3_progress_handler(
274 self.as_raw_handle().as_mut(),
275 num_ops,
276 Some(progress_callback::<F>),
277 handler,
278 );
279 }
280 }
281
282 pub fn remove_progress_handler(&mut self) {
284 self.guard.remove_progress_handler();
285 }
286}
287
288impl Drop for ConnectionState {
289 fn drop(&mut self) {
290 self.statements.clear();
292 self.remove_progress_handler();
293 }
294}
295
296impl Statements {
297 fn new(capacity: usize) -> Self {
298 Statements {
299 cached: StatementCache::new(capacity),
300 temp: None,
301 }
302 }
303
304 fn get(&mut self, query: &str, persistent: bool) -> Result<&mut VirtualStatement, Error> {
305 if !persistent || !self.cached.is_enabled() {
306 return Ok(self.temp.insert(VirtualStatement::new(query, false)?));
307 }
308
309 let exists = self.cached.contains_key(query);
310
311 if !exists {
312 let statement = VirtualStatement::new(query, true)?;
313 self.cached.insert(query, statement);
314 }
315
316 let statement = self.cached.get_mut(query).unwrap();
317
318 if exists {
319 statement.reset()?;
321 }
322
323 Ok(statement)
324 }
325
326 fn len(&self) -> usize {
327 self.cached.len()
328 }
329
330 fn clear(&mut self) {
331 self.cached.clear();
332 self.temp = None;
333 }
334}