// sqlx_sqlite/connection/establish.rs

1use crate::connection::handle::ConnectionHandle;
2use crate::connection::LogSettings;
3use crate::connection::{ConnectionState, Statements};
4use crate::error::Error;
5use crate::{SqliteConnectOptions, SqliteError};
6use libsqlite3_sys::{
7    sqlite3, sqlite3_busy_timeout, sqlite3_db_config, sqlite3_extended_result_codes, sqlite3_free,
8    sqlite3_load_extension, sqlite3_open_v2, SQLITE_DBCONFIG_ENABLE_LOAD_EXTENSION, SQLITE_OK,
9    SQLITE_OPEN_CREATE, SQLITE_OPEN_FULLMUTEX, SQLITE_OPEN_MEMORY, SQLITE_OPEN_NOMUTEX,
10    SQLITE_OPEN_PRIVATECACHE, SQLITE_OPEN_READONLY, SQLITE_OPEN_READWRITE, SQLITE_OPEN_SHAREDCACHE,
11    SQLITE_OPEN_URI,
12};
13use percent_encoding::NON_ALPHANUMERIC;
14use sqlx_core::IndexMap;
15use std::collections::BTreeMap;
16use std::ffi::{c_void, CStr, CString};
17use std::io;
18use std::os::raw::c_int;
19use std::ptr::{addr_of_mut, null, null_mut};
20use std::sync::atomic::{AtomicUsize, Ordering};
21use std::time::Duration;
22
/// Counter bumped once per `EstablishParams::from_options` call; each value it hands out is
/// passed to the configured `thread_name` callback.
///
/// Portability: this used to be an `AtomicU64`, but 64-bit atomics are not available on some
/// targets (e.g. MIPS and PowerPC), so a pointer-sized atomic is used instead. See
/// <https://github.com/launchbadge/sqlx/issues/2859> and
/// <https://doc.rust-lang.org/stable/std/sync/atomic/index.html#portability>.
static THREAD_ID: AtomicUsize = AtomicUsize::new(0_usize);
27
/// Value passed with `SQLITE_DBCONFIG_ENABLE_LOAD_EXTENSION` to toggle extension loading.
#[derive(Clone, Copy)]
enum SqliteLoadExtensionMode {
    /// Turn on extension loading through the C API only; the SQL function stays off.
    Enable,
    /// Turn off extension loading through both the C API and the SQL function.
    DisableAll,
}

impl SqliteLoadExtensionMode {
    /// Integer flag understood by `sqlite3_db_config`: `1` enables, `0` disables.
    fn to_int(self) -> c_int {
        match self {
            Self::DisableAll => 0,
            Self::Enable => 1,
        }
    }
}
44
45pub struct EstablishParams {
46    filename: CString,
47    open_flags: i32,
48    busy_timeout: Duration,
49    statement_cache_capacity: usize,
50    log_settings: LogSettings,
51    extensions: IndexMap<CString, Option<CString>>,
52    pub(crate) thread_name: String,
53    pub(crate) command_channel_size: usize,
54    #[cfg(feature = "regexp")]
55    register_regexp_function: bool,
56}
57
58impl EstablishParams {
59    pub fn from_options(options: &SqliteConnectOptions) -> Result<Self, Error> {
60        let mut filename = options
61            .filename
62            .to_str()
63            .ok_or_else(|| {
64                io::Error::new(
65                    io::ErrorKind::InvalidData,
66                    "filename passed to SQLite must be valid UTF-8",
67                )
68            })?
69            .to_owned();
70
71        // Set common flags we expect to have in sqlite
72        let mut flags = SQLITE_OPEN_URI;
73
74        // By default, we connect to an in-memory database.
75        // [SQLITE_OPEN_NOMUTEX] will instruct [sqlite3_open_v2] to return an error if it
76        // cannot satisfy our wish for a thread-safe, lock-free connection object
77
78        flags |= if options.serialized {
79            SQLITE_OPEN_FULLMUTEX
80        } else {
81            SQLITE_OPEN_NOMUTEX
82        };
83
84        flags |= if options.read_only {
85            SQLITE_OPEN_READONLY
86        } else if options.create_if_missing {
87            SQLITE_OPEN_CREATE | SQLITE_OPEN_READWRITE
88        } else {
89            SQLITE_OPEN_READWRITE
90        };
91
92        if options.in_memory {
93            flags |= SQLITE_OPEN_MEMORY;
94        }
95
96        flags |= if options.shared_cache {
97            SQLITE_OPEN_SHAREDCACHE
98        } else {
99            SQLITE_OPEN_PRIVATECACHE
100        };
101
102        let mut query_params = BTreeMap::new();
103
104        if options.immutable {
105            query_params.insert("immutable", "true");
106        }
107
108        if let Some(vfs) = options.vfs.as_deref() {
109            query_params.insert("vfs", vfs);
110        }
111
112        if !query_params.is_empty() {
113            filename = format!(
114                "file:{}?{}",
115                percent_encoding::percent_encode(filename.as_bytes(), NON_ALPHANUMERIC),
116                serde_urlencoded::to_string(&query_params).unwrap()
117            );
118        }
119
120        let filename = CString::new(filename).map_err(|_| {
121            io::Error::new(
122                io::ErrorKind::InvalidData,
123                "filename passed to SQLite must not contain nul bytes",
124            )
125        })?;
126
127        let extensions = options
128            .extensions
129            .iter()
130            .map(|(name, entry)| {
131                let entry = entry
132                    .as_ref()
133                    .map(|e| {
134                        CString::new(e.as_bytes()).map_err(|_| {
135                            io::Error::new(
136                                io::ErrorKind::InvalidData,
137                                "extension entrypoint names passed to SQLite must not contain nul bytes"
138                            )
139                        })
140                    })
141                    .transpose()?;
142                Ok((
143                    CString::new(name.as_bytes()).map_err(|_| {
144                        io::Error::new(
145                            io::ErrorKind::InvalidData,
146                            "extension names passed to SQLite must not contain nul bytes",
147                        )
148                    })?,
149                    entry,
150                ))
151            })
152            .collect::<Result<IndexMap<CString, Option<CString>>, io::Error>>()?;
153
154        let thread_id = THREAD_ID.fetch_add(1, Ordering::AcqRel);
155
156        Ok(Self {
157            filename,
158            open_flags: flags,
159            busy_timeout: options.busy_timeout,
160            statement_cache_capacity: options.statement_cache_capacity,
161            log_settings: options.log_settings.clone(),
162            extensions,
163            thread_name: (options.thread_name)(thread_id as u64),
164            command_channel_size: options.command_channel_size,
165            #[cfg(feature = "regexp")]
166            register_regexp_function: options.register_regexp_function,
167        })
168    }
169
170    // Enable extension loading via the db_config function, as recommended by the docs rather
171    // than the more obvious `sqlite3_enable_load_extension`
172    // https://www.sqlite.org/c3ref/db_config.html
173    // https://www.sqlite.org/c3ref/c_dbconfig_defensive.html#sqlitedbconfigenableloadextension
174    unsafe fn sqlite3_set_load_extension(
175        db: *mut sqlite3,
176        mode: SqliteLoadExtensionMode,
177    ) -> Result<(), Error> {
178        let status = sqlite3_db_config(
179            db,
180            SQLITE_DBCONFIG_ENABLE_LOAD_EXTENSION,
181            mode.to_int(),
182            null::<i32>(),
183        );
184
185        if status != SQLITE_OK {
186            return Err(Error::Database(Box::new(SqliteError::new(db))));
187        }
188
189        Ok(())
190    }
191
192    pub(crate) fn establish(&self) -> Result<ConnectionState, Error> {
193        let mut handle = null_mut();
194
195        // <https://www.sqlite.org/c3ref/open.html>
196        let mut status = unsafe {
197            sqlite3_open_v2(self.filename.as_ptr(), &mut handle, self.open_flags, null())
198        };
199
200        if handle.is_null() {
201            // Failed to allocate memory
202            return Err(Error::Io(io::Error::new(
203                io::ErrorKind::OutOfMemory,
204                "SQLite is unable to allocate memory to hold the sqlite3 object",
205            )));
206        }
207
208        // SAFE: tested for NULL just above
209        // This allows any returns below to close this handle with RAII
210        let mut handle = unsafe { ConnectionHandle::new(handle) };
211
212        if status != SQLITE_OK {
213            return Err(Error::Database(Box::new(handle.expect_error())));
214        }
215
216        // Enable extended result codes
217        // https://www.sqlite.org/c3ref/extended_result_codes.html
218        unsafe {
219            // NOTE: ignore the failure here
220            sqlite3_extended_result_codes(handle.as_ptr(), 1);
221        }
222
223        if !self.extensions.is_empty() {
224            // Enable loading extensions
225            unsafe {
226                Self::sqlite3_set_load_extension(handle.as_ptr(), SqliteLoadExtensionMode::Enable)?;
227            }
228
229            for ext in self.extensions.iter() {
230                // `sqlite3_load_extension` is unusual as it returns its errors via an out-pointer
231                // rather than by calling `sqlite3_errmsg`
232                let mut error_msg = null_mut();
233                status = unsafe {
234                    sqlite3_load_extension(
235                        handle.as_ptr(),
236                        ext.0.as_ptr(),
237                        ext.1.as_ref().map_or(null(), |e| e.as_ptr()),
238                        addr_of_mut!(error_msg),
239                    )
240                };
241
242                if status != SQLITE_OK {
243                    let mut e = handle.expect_error();
244
245                    // SAFETY: We become responsible for any memory allocation at `&error`, so test
246                    // for null and take an RAII version for returns
247                    if !error_msg.is_null() {
248                        e = e.with_message(unsafe {
249                            let msg = CStr::from_ptr(error_msg).to_string_lossy().into();
250                            sqlite3_free(error_msg as *mut c_void);
251                            msg
252                        });
253                    }
254                    return Err(Error::Database(Box::new(e)));
255                }
256            } // Preempt any hypothetical security issues arising from leaving ENABLE_LOAD_EXTENSION
257              // on by disabling the flag again once we've loaded all the requested modules.
258              // Fail-fast (via `?`) if disabling the extension loader didn't work for some reason,
259              // avoids an unexpected state going undetected.
260            unsafe {
261                Self::sqlite3_set_load_extension(
262                    handle.as_ptr(),
263                    SqliteLoadExtensionMode::DisableAll,
264                )?;
265            }
266        }
267
268        #[cfg(feature = "regexp")]
269        if self.register_regexp_function {
270            // configure a `regexp` function for sqlite, it does not come with one by default
271            let status = crate::regexp::register(handle.as_ptr());
272            if status != SQLITE_OK {
273                return Err(Error::Database(Box::new(handle.expect_error())));
274            }
275        }
276
277        // Configure a busy timeout
278        // This causes SQLite to automatically sleep in increasing intervals until the time
279        // when there is something locked during [sqlite3_step].
280        //
281        // We also need to convert the u128 value to i32, checking we're not overflowing.
282        let ms = i32::try_from(self.busy_timeout.as_millis())
283            .expect("Given busy timeout value is too big.");
284
285        status = unsafe { sqlite3_busy_timeout(handle.as_ptr(), ms) };
286
287        if status != SQLITE_OK {
288            return Err(Error::Database(Box::new(handle.expect_error())));
289        }
290
291        Ok(ConnectionState {
292            handle,
293            statements: Statements::new(self.statement_cache_capacity),
294            log_settings: self.log_settings.clone(),
295            progress_handler_callback: None,
296            update_hook_callback: None,
297            #[cfg(feature = "preupdate-hook")]
298            preupdate_hook_callback: None,
299            commit_hook_callback: None,
300            rollback_hook_callback: None,
301        })
302    }
303}