1use crate::connection::handle::ConnectionHandle;
2use crate::connection::LogSettings;
3use crate::connection::{ConnectionState, Statements};
4use crate::error::Error;
5use crate::{SqliteConnectOptions, SqliteError};
6use libsqlite3_sys::{
7 sqlite3, sqlite3_busy_timeout, sqlite3_db_config, sqlite3_extended_result_codes, sqlite3_free,
8 sqlite3_load_extension, sqlite3_open_v2, SQLITE_DBCONFIG_ENABLE_LOAD_EXTENSION, SQLITE_OK,
9 SQLITE_OPEN_CREATE, SQLITE_OPEN_FULLMUTEX, SQLITE_OPEN_MEMORY, SQLITE_OPEN_NOMUTEX,
10 SQLITE_OPEN_PRIVATECACHE, SQLITE_OPEN_READONLY, SQLITE_OPEN_READWRITE, SQLITE_OPEN_SHAREDCACHE,
11 SQLITE_OPEN_URI,
12};
13use percent_encoding::NON_ALPHANUMERIC;
14use sqlx_core::IndexMap;
15use std::collections::BTreeMap;
16use std::ffi::{c_void, CStr, CString};
17use std::io;
18use std::os::raw::c_int;
19use std::ptr::{addr_of_mut, null, null_mut};
20use std::sync::atomic::{AtomicUsize, Ordering};
21use std::time::Duration;
22
/// Process-wide counter; each call to `EstablishParams::from_options` takes the next value
/// and passes it to the configured `thread_name` callback so every connection's thread gets a
/// distinct numeric id.
static THREAD_ID: AtomicUsize = AtomicUsize::new(0);
27
/// Setting written through `SQLITE_DBCONFIG_ENABLE_LOAD_EXTENSION` to control whether
/// `sqlite3_load_extension()` may be called on a connection.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum SqliteLoadExtensionMode {
    /// Turns on the C-level `sqlite3_load_extension()` API. Per the SQLite docs this option
    /// does not enable the `load_extension()` SQL function.
    Enable,
    /// Turns `sqlite3_load_extension()` back off.
    DisableAll,
}
35
36impl SqliteLoadExtensionMode {
37 fn to_int(self) -> c_int {
38 match self {
39 SqliteLoadExtensionMode::Enable => 1,
40 SqliteLoadExtensionMode::DisableAll => 0,
41 }
42 }
43}
44
45pub struct EstablishParams {
46 filename: CString,
47 open_flags: i32,
48 busy_timeout: Duration,
49 statement_cache_capacity: usize,
50 log_settings: LogSettings,
51 extensions: IndexMap<CString, Option<CString>>,
52 pub(crate) thread_name: String,
53 pub(crate) command_channel_size: usize,
54 #[cfg(feature = "regexp")]
55 register_regexp_function: bool,
56}
57
58impl EstablishParams {
59 pub fn from_options(options: &SqliteConnectOptions) -> Result<Self, Error> {
60 let mut filename = options
61 .filename
62 .to_str()
63 .ok_or_else(|| {
64 io::Error::new(
65 io::ErrorKind::InvalidData,
66 "filename passed to SQLite must be valid UTF-8",
67 )
68 })?
69 .to_owned();
70
71 let mut flags = SQLITE_OPEN_URI;
73
74 flags |= if options.serialized {
79 SQLITE_OPEN_FULLMUTEX
80 } else {
81 SQLITE_OPEN_NOMUTEX
82 };
83
84 flags |= if options.read_only {
85 SQLITE_OPEN_READONLY
86 } else if options.create_if_missing {
87 SQLITE_OPEN_CREATE | SQLITE_OPEN_READWRITE
88 } else {
89 SQLITE_OPEN_READWRITE
90 };
91
92 if options.in_memory {
93 flags |= SQLITE_OPEN_MEMORY;
94 }
95
96 flags |= if options.shared_cache {
97 SQLITE_OPEN_SHAREDCACHE
98 } else {
99 SQLITE_OPEN_PRIVATECACHE
100 };
101
102 let mut query_params = BTreeMap::new();
103
104 if options.immutable {
105 query_params.insert("immutable", "true");
106 }
107
108 if let Some(vfs) = options.vfs.as_deref() {
109 query_params.insert("vfs", vfs);
110 }
111
112 if !query_params.is_empty() {
113 filename = format!(
114 "file:{}?{}",
115 percent_encoding::percent_encode(filename.as_bytes(), NON_ALPHANUMERIC),
116 serde_urlencoded::to_string(&query_params).unwrap()
117 );
118 }
119
120 let filename = CString::new(filename).map_err(|_| {
121 io::Error::new(
122 io::ErrorKind::InvalidData,
123 "filename passed to SQLite must not contain nul bytes",
124 )
125 })?;
126
127 let extensions = options
128 .extensions
129 .iter()
130 .map(|(name, entry)| {
131 let entry = entry
132 .as_ref()
133 .map(|e| {
134 CString::new(e.as_bytes()).map_err(|_| {
135 io::Error::new(
136 io::ErrorKind::InvalidData,
137 "extension entrypoint names passed to SQLite must not contain nul bytes"
138 )
139 })
140 })
141 .transpose()?;
142 Ok((
143 CString::new(name.as_bytes()).map_err(|_| {
144 io::Error::new(
145 io::ErrorKind::InvalidData,
146 "extension names passed to SQLite must not contain nul bytes",
147 )
148 })?,
149 entry,
150 ))
151 })
152 .collect::<Result<IndexMap<CString, Option<CString>>, io::Error>>()?;
153
154 let thread_id = THREAD_ID.fetch_add(1, Ordering::AcqRel);
155
156 Ok(Self {
157 filename,
158 open_flags: flags,
159 busy_timeout: options.busy_timeout,
160 statement_cache_capacity: options.statement_cache_capacity,
161 log_settings: options.log_settings.clone(),
162 extensions,
163 thread_name: (options.thread_name)(thread_id as u64),
164 command_channel_size: options.command_channel_size,
165 #[cfg(feature = "regexp")]
166 register_regexp_function: options.register_regexp_function,
167 })
168 }
169
170 unsafe fn sqlite3_set_load_extension(
175 db: *mut sqlite3,
176 mode: SqliteLoadExtensionMode,
177 ) -> Result<(), Error> {
178 let status = sqlite3_db_config(
179 db,
180 SQLITE_DBCONFIG_ENABLE_LOAD_EXTENSION,
181 mode.to_int(),
182 null::<i32>(),
183 );
184
185 if status != SQLITE_OK {
186 return Err(Error::Database(Box::new(SqliteError::new(db))));
187 }
188
189 Ok(())
190 }
191
192 pub(crate) fn establish(&self) -> Result<ConnectionState, Error> {
193 let mut handle = null_mut();
194
195 let mut status = unsafe {
197 sqlite3_open_v2(self.filename.as_ptr(), &mut handle, self.open_flags, null())
198 };
199
200 if handle.is_null() {
201 return Err(Error::Io(io::Error::new(
203 io::ErrorKind::OutOfMemory,
204 "SQLite is unable to allocate memory to hold the sqlite3 object",
205 )));
206 }
207
208 let mut handle = unsafe { ConnectionHandle::new(handle) };
211
212 if status != SQLITE_OK {
213 return Err(Error::Database(Box::new(handle.expect_error())));
214 }
215
216 unsafe {
219 sqlite3_extended_result_codes(handle.as_ptr(), 1);
221 }
222
223 if !self.extensions.is_empty() {
224 unsafe {
226 Self::sqlite3_set_load_extension(handle.as_ptr(), SqliteLoadExtensionMode::Enable)?;
227 }
228
229 for ext in self.extensions.iter() {
230 let mut error_msg = null_mut();
233 status = unsafe {
234 sqlite3_load_extension(
235 handle.as_ptr(),
236 ext.0.as_ptr(),
237 ext.1.as_ref().map_or(null(), |e| e.as_ptr()),
238 addr_of_mut!(error_msg),
239 )
240 };
241
242 if status != SQLITE_OK {
243 let mut e = handle.expect_error();
244
245 if !error_msg.is_null() {
248 e = e.with_message(unsafe {
249 let msg = CStr::from_ptr(error_msg).to_string_lossy().into();
250 sqlite3_free(error_msg as *mut c_void);
251 msg
252 });
253 }
254 return Err(Error::Database(Box::new(e)));
255 }
256 } unsafe {
261 Self::sqlite3_set_load_extension(
262 handle.as_ptr(),
263 SqliteLoadExtensionMode::DisableAll,
264 )?;
265 }
266 }
267
268 #[cfg(feature = "regexp")]
269 if self.register_regexp_function {
270 let status = crate::regexp::register(handle.as_ptr());
272 if status != SQLITE_OK {
273 return Err(Error::Database(Box::new(handle.expect_error())));
274 }
275 }
276
277 let ms = i32::try_from(self.busy_timeout.as_millis())
283 .expect("Given busy timeout value is too big.");
284
285 status = unsafe { sqlite3_busy_timeout(handle.as_ptr(), ms) };
286
287 if status != SQLITE_OK {
288 return Err(Error::Database(Box::new(handle.expect_error())));
289 }
290
291 Ok(ConnectionState {
292 handle,
293 statements: Statements::new(self.statement_cache_capacity),
294 log_settings: self.log_settings.clone(),
295 progress_handler_callback: None,
296 update_hook_callback: None,
297 #[cfg(feature = "preupdate-hook")]
298 preupdate_hook_callback: None,
299 commit_hook_callback: None,
300 rollback_hook_callback: None,
301 })
302 }
303}