//! `feature = "trace"` Tracing and profiling functions. Error and warning log.

use std::ffi::{CStr, CString};
use std::mem;
use std::os::raw::{c_char, c_int, c_void};
use std::panic::catch_unwind;
use std::ptr;
use std::time::Duration;

use super::ffi;
use crate::error::error_from_sqlite_code;
use crate::{Connection, Result};

/// `feature = "trace"` Install or remove the process-wide SQLite error
/// logging callback.
///
/// `Some(callback)` registers `callback` through `SQLITE_CONFIG_LOG`; `None`
/// passes null pointers instead, clearing any previously installed callback.
/// The message handed to `callback` is decoded lossily, so invalid UTF-8 is
/// replaced rather than rejected, and a panic raised by `callback` is caught
/// and discarded.
///
/// # Errors
///
/// Returns an error built from the SQLite result code whenever
/// `sqlite3_config` reports anything other than `SQLITE_OK`.
///
/// # Safety
///
/// This function is marked unsafe for two reasons:
///
/// * The function is not threadsafe. No other SQLite calls may be made while
///   `config_log` is running, and multiple threads may not call `config_log`
///   simultaneously.
/// * The provided `callback` function itself has two requirements:
///     * It must not invoke any SQLite calls.
///     * It must be threadsafe if SQLite is used in a multithreaded way.
///
/// cf [The Error And Warning Log](http://sqlite.org/errlog.html).
pub unsafe fn config_log(callback: Option<fn(c_int, &str)>) -> Result<()> {
    extern "C" fn log_callback(p_arg: *mut c_void, err: c_int, msg: *const c_char) {
        // `p_arg` is the Rust fn pointer that was registered below as the
        // callback's user-data argument.
        let user_fn: fn(c_int, &str) = unsafe { mem::transmute(p_arg) };
        let text = unsafe { CStr::from_ptr(msg) }.to_string_lossy();

        // Swallow panics so they never unwind out of this `extern "C"` frame.
        drop(catch_unwind(|| user_fn(err, &text)));
    }

    let rc = if let Some(f) = callback {
        ffi::sqlite3_config(
            ffi::SQLITE_CONFIG_LOG,
            log_callback as extern "C" fn(_, _, _),
            f as *mut c_void,
        )
    } else {
        let nullptr: *mut c_void = ptr::null_mut();
        ffi::sqlite3_config(ffi::SQLITE_CONFIG_LOG, nullptr, nullptr)
    };

    if rc != ffi::SQLITE_OK {
        return Err(error_from_sqlite_code(rc, None));
    }
    Ok(())
}

/// `feature = "trace"` Write `msg`, tagged with `err_code`, into the error
/// log established by `config_log`.
///
/// # Panics
///
/// Panics if `msg` contains an embedded NUL byte, since it cannot then be
/// converted into a C string.
pub fn log(err_code: c_int, msg: &str) {
    // `msg` is always passed as the argument of this fixed format, never as
    // the format string itself.
    const FORMAT: &[u8] = b"%s\0";

    let c_msg = CString::new(msg).expect("SQLite log messages cannot contain embedded zeroes");
    unsafe {
        ffi::sqlite3_log(err_code, FORMAT.as_ptr() as *const c_char, c_msg.as_ptr());
    }
}

impl Connection {
    /// `feature = "trace"` Register or clear a callback function that can be
    /// used for tracing the execution of SQL statements.
    ///
    /// Prepared statement placeholders are replaced/logged with their assigned
    /// values. There can only be a single tracer defined for each database
    /// connection. Setting a new tracer clears the old one; passing `None`
    /// removes the tracer altogether.
    ///
    /// The SQL text is decoded lossily before being handed to `trace_fn`, and
    /// a panic raised inside `trace_fn` is caught and discarded.
    pub fn trace(&mut self, trace_fn: Option<fn(&str)>) {
        unsafe extern "C" fn trace_callback(p_arg: *mut c_void, z_sql: *const c_char) {
            // `p_arg` carries the Rust fn pointer registered below.
            let user_fn: fn(&str) = mem::transmute(p_arg);
            let sql = CStr::from_ptr(z_sql).to_string_lossy();
            // Never let a panic unwind out of this `extern "C"` frame.
            drop(catch_unwind(|| user_fn(&sql)));
        }

        let conn = self.db.borrow_mut();
        unsafe {
            if let Some(f) = trace_fn {
                ffi::sqlite3_trace(conn.db(), Some(trace_callback), f as *mut c_void);
            } else {
                ffi::sqlite3_trace(conn.db(), None, ptr::null_mut());
            }
        }
    }

    /// `feature = "trace"` Register or clear a callback function that can be
    /// used for profiling the execution of SQL statements.
    ///
    /// There can only be a single profiler defined for each database
    /// connection. Setting a new profiler clears the old one; passing `None`
    /// removes the profiler altogether.
    ///
    /// `profile_fn` receives the lossily decoded SQL text together with the
    /// reported nanosecond count converted into a `Duration`. A panic raised
    /// inside `profile_fn` is caught and discarded.
    pub fn profile(&mut self, profile_fn: Option<fn(&str, Duration)>) {
        unsafe extern "C" fn profile_callback(
            p_arg: *mut c_void,
            z_sql: *const c_char,
            nanoseconds: u64,
        ) {
            // `p_arg` carries the Rust fn pointer registered below.
            let user_fn: fn(&str, Duration) = mem::transmute(p_arg);
            let sql = CStr::from_ptr(z_sql).to_string_lossy();
            let elapsed = Duration::from_nanos(nanoseconds);
            // Never let a panic unwind out of this `extern "C"` frame.
            drop(catch_unwind(|| user_fn(&sql, elapsed)));
        }

        let conn = self.db.borrow_mut();
        unsafe {
            if let Some(f) = profile_fn {
                ffi::sqlite3_profile(conn.db(), Some(profile_callback), f as *mut c_void);
            } else {
                ffi::sqlite3_profile(conn.db(), None, ptr::null_mut());
            }
        }
    }
}

#[cfg(test)]
mod test {
    use lazy_static::lazy_static;
    use std::sync::Mutex;
    use std::time::Duration;

    use crate::Connection;

    /// Statements run while a tracer is installed are recorded with their
    /// bound values substituted; statements run after it is cleared are not.
    #[test]
    fn test_trace() {
        lazy_static! {
            static ref TRACED_STMTS: Mutex<Vec<String>> = Mutex::new(Vec::new());
        }
        // Plain `fn` (not a closure) because `trace` takes a fn pointer.
        fn tracer(s: &str) {
            TRACED_STMTS.lock().unwrap().push(s.to_owned());
        }

        let mut db = Connection::open_in_memory().unwrap();

        // Tracer installed: both statements must be captured.
        db.trace(Some(tracer));
        let _ = db.query_row("SELECT ?", &[1i32], |_| Ok(()));
        let _ = db.query_row("SELECT ?", &["hello"], |_| Ok(()));

        // Tracer cleared: nothing further may be captured.
        db.trace(None);
        let _ = db.query_row("SELECT ?", &[2i32], |_| Ok(()));
        let _ = db.query_row("SELECT ?", &["goodbye"], |_| Ok(()));

        let captured = TRACED_STMTS.lock().unwrap();
        assert_eq!(*captured, ["SELECT 1", "SELECT 'hello'"]);
    }

    /// Only the statement executed while a profiler is installed is reported.
    #[test]
    fn test_profile() {
        lazy_static! {
            static ref PROFILED: Mutex<Vec<(String, Duration)>> = Mutex::new(Vec::new());
        }
        // Plain `fn` (not a closure) because `profile` takes a fn pointer.
        fn profiler(s: &str, d: Duration) {
            PROFILED.lock().unwrap().push((s.to_owned(), d));
        }

        let mut db = Connection::open_in_memory().unwrap();

        db.profile(Some(profiler));
        db.execute_batch("PRAGMA application_id = 1").unwrap();

        db.profile(None);
        db.execute_batch("PRAGMA application_id = 2").unwrap();

        let recorded = PROFILED.lock().unwrap();
        assert_eq!(recorded.len(), 1);
        assert_eq!(recorded[0].0.as_str(), "PRAGMA application_id = 1");
    }
}