use crate::connection::handle::ConnectionHandle;
use crate::connection::LogSettings;
use crate::connection::{ConnectionState, Statements};
use crate::error::Error;
use crate::{SqliteConnectOptions, SqliteError};
use libsqlite3_sys::{
sqlite3, sqlite3_busy_timeout, sqlite3_db_config, sqlite3_extended_result_codes, sqlite3_free,
sqlite3_load_extension, sqlite3_open_v2, SQLITE_DBCONFIG_ENABLE_LOAD_EXTENSION, SQLITE_OK,
SQLITE_OPEN_CREATE, SQLITE_OPEN_FULLMUTEX, SQLITE_OPEN_MEMORY, SQLITE_OPEN_NOMUTEX,
SQLITE_OPEN_PRIVATECACHE, SQLITE_OPEN_READONLY, SQLITE_OPEN_READWRITE, SQLITE_OPEN_SHAREDCACHE,
};
use percent_encoding::NON_ALPHANUMERIC;
use sqlx_core::IndexMap;
use std::collections::BTreeMap;
use std::ffi::{c_void, CStr, CString};
use std::io;
use std::os::raw::c_int;
use std::ptr::{addr_of_mut, null, null_mut};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration;
/// Process-wide counter. Each `EstablishParams::from_options` call takes the next value and
/// passes it to `options.thread_name` to build `EstablishParams::thread_name`.
static THREAD_ID: AtomicUsize = AtomicUsize::new(0);
/// Requested state of SQLite's C-API extension loader
/// (`SQLITE_DBCONFIG_ENABLE_LOAD_EXTENSION`) for one connection.
#[derive(Clone, Copy)]
enum SqliteLoadExtensionMode {
    /// Permit `sqlite3_load_extension` on the connection.
    Enable,
    /// Turn the extension loader off again.
    DisableAll,
}
impl SqliteLoadExtensionMode {
fn to_int(self) -> c_int {
match self {
SqliteLoadExtensionMode::Enable => 1,
SqliteLoadExtensionMode::DisableAll => 0,
}
}
}
/// Pre-validated, FFI-ready parameters for opening a SQLite connection, derived from
/// `SqliteConnectOptions` by `EstablishParams::from_options`.
pub struct EstablishParams {
    /// Path (or `file:` URI, when query parameters are needed) handed to `sqlite3_open_v2`.
    filename: CString,
    /// `SQLITE_OPEN_*` flags handed to `sqlite3_open_v2`.
    open_flags: i32,
    /// Value applied with `sqlite3_busy_timeout` after opening.
    busy_timeout: Duration,
    /// Capacity used to create the connection's prepared-statement cache.
    statement_cache_capacity: usize,
    /// Logging configuration copied into each established connection.
    log_settings: LogSettings,
    /// Extension library names mapped to an optional entry-point symbol, loaded in order.
    extensions: IndexMap<CString, Option<CString>>,
    /// Name produced by the options' `thread_name` callback
    /// (presumably used for the connection's worker thread — confirm at the use site).
    pub(crate) thread_name: String,
    /// Copied from the options' `command_channel_size`
    /// (presumably the bound of the command channel — confirm at the use site).
    pub(crate) command_channel_size: usize,
    /// Whether `establish` registers the crate's regexp function on the new connection.
    #[cfg(feature = "regexp")]
    register_regexp_function: bool,
}
impl EstablishParams {
pub fn from_options(options: &SqliteConnectOptions) -> Result<Self, Error> {
let mut filename = options
.filename
.to_str()
.ok_or_else(|| {
io::Error::new(
io::ErrorKind::InvalidData,
"filename passed to SQLite must be valid UTF-8",
)
})?
.to_owned();
let mut flags = if options.serialized {
SQLITE_OPEN_FULLMUTEX
} else {
SQLITE_OPEN_NOMUTEX
};
flags |= if options.read_only {
SQLITE_OPEN_READONLY
} else if options.create_if_missing {
SQLITE_OPEN_CREATE | SQLITE_OPEN_READWRITE
} else {
SQLITE_OPEN_READWRITE
};
if options.in_memory {
flags |= SQLITE_OPEN_MEMORY;
}
flags |= if options.shared_cache {
SQLITE_OPEN_SHAREDCACHE
} else {
SQLITE_OPEN_PRIVATECACHE
};
let mut query_params = BTreeMap::new();
if options.immutable {
query_params.insert("immutable", "true");
}
if let Some(vfs) = options.vfs.as_deref() {
query_params.insert("vfs", vfs);
}
if !query_params.is_empty() {
filename = format!(
"file:{}?{}",
percent_encoding::percent_encode(filename.as_bytes(), NON_ALPHANUMERIC),
serde_urlencoded::to_string(&query_params).unwrap()
);
flags |= libsqlite3_sys::SQLITE_OPEN_URI;
}
let filename = CString::new(filename).map_err(|_| {
io::Error::new(
io::ErrorKind::InvalidData,
"filename passed to SQLite must not contain nul bytes",
)
})?;
let extensions = options
.extensions
.iter()
.map(|(name, entry)| {
let entry = entry
.as_ref()
.map(|e| {
CString::new(e.as_bytes()).map_err(|_| {
io::Error::new(
io::ErrorKind::InvalidData,
"extension entrypoint names passed to SQLite must not contain nul bytes"
)
})
})
.transpose()?;
Ok((
CString::new(name.as_bytes()).map_err(|_| {
io::Error::new(
io::ErrorKind::InvalidData,
"extension names passed to SQLite must not contain nul bytes",
)
})?,
entry,
))
})
.collect::<Result<IndexMap<CString, Option<CString>>, io::Error>>()?;
let thread_id = THREAD_ID.fetch_add(1, Ordering::AcqRel);
Ok(Self {
filename,
open_flags: flags,
busy_timeout: options.busy_timeout,
statement_cache_capacity: options.statement_cache_capacity,
log_settings: options.log_settings.clone(),
extensions,
thread_name: (options.thread_name)(thread_id as u64),
command_channel_size: options.command_channel_size,
#[cfg(feature = "regexp")]
register_regexp_function: options.register_regexp_function,
})
}
unsafe fn sqlite3_set_load_extension(
db: *mut sqlite3,
mode: SqliteLoadExtensionMode,
) -> Result<(), Error> {
let status = sqlite3_db_config(
db,
SQLITE_DBCONFIG_ENABLE_LOAD_EXTENSION,
mode.to_int(),
null::<i32>(),
);
if status != SQLITE_OK {
return Err(Error::Database(Box::new(SqliteError::new(db))));
}
Ok(())
}
pub(crate) fn establish(&self) -> Result<ConnectionState, Error> {
let mut handle = null_mut();
let mut status = unsafe {
sqlite3_open_v2(self.filename.as_ptr(), &mut handle, self.open_flags, null())
};
if handle.is_null() {
return Err(Error::Io(io::Error::new(
io::ErrorKind::OutOfMemory,
"SQLite is unable to allocate memory to hold the sqlite3 object",
)));
}
let handle = unsafe { ConnectionHandle::new(handle) };
if status != SQLITE_OK {
return Err(Error::Database(Box::new(SqliteError::new(handle.as_ptr()))));
}
unsafe {
sqlite3_extended_result_codes(handle.as_ptr(), 1);
}
if !self.extensions.is_empty() {
unsafe {
Self::sqlite3_set_load_extension(handle.as_ptr(), SqliteLoadExtensionMode::Enable)?;
}
for ext in self.extensions.iter() {
let mut error = null_mut();
status = unsafe {
sqlite3_load_extension(
handle.as_ptr(),
ext.0.as_ptr(),
ext.1.as_ref().map_or(null(), |e| e.as_ptr()),
addr_of_mut!(error),
)
};
if status != SQLITE_OK {
let err_msg = if !error.is_null() {
unsafe {
let e = CStr::from_ptr(error).into();
sqlite3_free(error as *mut c_void);
e
}
} else {
CString::new("Unknown error when loading extension")
.expect("text should be representable as a CString")
};
return Err(Error::Database(Box::new(SqliteError::extension(
handle.as_ptr(),
&err_msg,
))));
}
} unsafe {
Self::sqlite3_set_load_extension(
handle.as_ptr(),
SqliteLoadExtensionMode::DisableAll,
)?;
}
}
#[cfg(feature = "regexp")]
if self.register_regexp_function {
let status = crate::regexp::register(handle.as_ptr());
if status != SQLITE_OK {
return Err(Error::Database(Box::new(SqliteError::new(handle.as_ptr()))));
}
}
let ms = i32::try_from(self.busy_timeout.as_millis())
.expect("Given busy timeout value is too big.");
status = unsafe { sqlite3_busy_timeout(handle.as_ptr(), ms) };
if status != SQLITE_OK {
return Err(Error::Database(Box::new(SqliteError::new(handle.as_ptr()))));
}
Ok(ConnectionState {
handle,
statements: Statements::new(self.statement_cache_capacity),
transaction_depth: 0,
log_settings: self.log_settings.clone(),
progress_handler_callback: None,
update_hook_callback: None,
})
}
}