sqlite_compressions/
common.rs

1use std::panic::{RefUnwindSafe, UnwindSafe};
2
3#[cfg(feature = "trace")]
4use log::trace;
5use rusqlite::functions::Context;
6use rusqlite::types::{Type, ValueRef};
7use rusqlite::Error::{InvalidFunctionParameterType, InvalidParameterCount};
8
9use crate::rusqlite::functions::FunctionFlags;
10use crate::rusqlite::{Connection, Result};
11
// Fallback used when the `trace` feature is disabled: it accepts any tokens and
// expands to nothing, so `trace!(...)` call sites compile without the `log` import.
#[cfg(not(feature = "trace"))]
macro_rules! trace {
    ( $( $ignored:tt )* ) => {};
}
16
17pub trait Encoder {
18    fn enc_name() -> &'static str;
19    fn dec_name() -> &'static str;
20    fn test_name() -> &'static str;
21    fn encode(data: &[u8], quality: Option<u32>) -> Result<Vec<u8>>;
22    fn decode(data: &[u8]) -> Result<Vec<u8>>;
23    fn test(data: &[u8]) -> bool;
24}
25
26pub(crate) fn register_compression<T: Encoder + UnwindSafe + RefUnwindSafe + 'static>(
27    conn: &Connection,
28) -> Result<()> {
29    // FunctionFlags derive Copy trait only in v0.31+, but we support v0.30+
30    macro_rules! flags {
31        () => {
32            FunctionFlags::SQLITE_UTF8
33                | FunctionFlags::SQLITE_DETERMINISTIC
34                | FunctionFlags::SQLITE_DIRECTONLY
35        };
36    }
37
38    trace!("Registering function {}", T::enc_name());
39    conn.create_scalar_function(T::enc_name(), -1, flags!(), encoder_fn::<T>)?;
40
41    trace!("Registering function {}", T::dec_name());
42    conn.create_scalar_function(T::dec_name(), -1, flags!(), decoder_fn::<T>)?;
43
44    trace!("Registering function {}", T::test_name());
45    conn.create_scalar_function(T::test_name(), -1, flags!(), testing_fn::<T>)
46}
47
48fn encoder_fn<T: Encoder + UnwindSafe + RefUnwindSafe + 'static>(
49    ctx: &Context,
50) -> Result<Option<Vec<u8>>> {
51    let param_count = ctx.len();
52    if param_count == 0 || param_count > 2 {
53        return Err(InvalidParameterCount(param_count, 1));
54    }
55    let quality = if param_count == 2 {
56        Some(ctx.get::<u32>(1)?)
57    } else {
58        None
59    };
60
61    let value = ctx.get_raw(0);
62    match value {
63        ValueRef::Blob(val) => {
64            trace!("{}: encoding blob {val:?}", T::enc_name());
65            Ok(Some(T::encode(val, quality)?))
66        }
67        ValueRef::Text(val) => {
68            trace!("{}: encoding text {val:?}", T::enc_name());
69            Ok(Some(T::encode(val, quality)?))
70        }
71        ValueRef::Null => {
72            trace!("{}: ignoring NULL", T::enc_name());
73            Ok(None)
74        }
75        #[allow(unused_variables)]
76        ValueRef::Integer(val) => {
77            trace!("{}: unsupported Integer {val:?}", T::enc_name());
78            Err(InvalidFunctionParameterType(0, Type::Integer))
79        }
80        #[allow(unused_variables)]
81        ValueRef::Real(val) => {
82            trace!("{}: unsupported Real {val:?}", T::enc_name());
83            Err(InvalidFunctionParameterType(0, Type::Real))
84        }
85    }
86}
87
88fn decoder_fn<T: Encoder + UnwindSafe + RefUnwindSafe + 'static>(
89    ctx: &Context,
90) -> Result<Option<Vec<u8>>> {
91    let param_count = ctx.len();
92    if param_count != 1 {
93        return Err(InvalidParameterCount(param_count, 1));
94    }
95
96    let value = ctx.get_raw(0);
97    match value {
98        ValueRef::Blob(val) => {
99            trace!("{}: decoding blob {val:?}", T::dec_name());
100            Ok(Some(T::decode(val)?))
101        }
102        ValueRef::Null => {
103            trace!("{}: ignoring NULL", T::dec_name());
104            Ok(None)
105        }
106        #[allow(unused_variables)]
107        ValueRef::Text(val) => {
108            trace!("{}: unsupported Text {val:?}", T::dec_name());
109            Err(InvalidFunctionParameterType(0, Type::Text))
110        }
111        #[allow(unused_variables)]
112        ValueRef::Integer(val) => {
113            trace!("{}: unsupported Integer {val:?}", T::dec_name());
114            Err(InvalidFunctionParameterType(0, Type::Integer))
115        }
116        #[allow(unused_variables)]
117        ValueRef::Real(val) => {
118            trace!("{}: unsupported Real {val:?}", T::dec_name());
119            Err(InvalidFunctionParameterType(0, Type::Real))
120        }
121    }
122}
123
124fn testing_fn<T: Encoder + UnwindSafe + RefUnwindSafe + 'static>(
125    ctx: &Context,
126) -> Result<Option<bool>> {
127    let param_count = ctx.len();
128    if param_count != 1 {
129        return Err(InvalidParameterCount(param_count, 1));
130    }
131
132    let value = ctx.get_raw(0);
133    match value {
134        ValueRef::Blob(val) => {
135            trace!("{}: testing encoded blob {val:?}", T::test_name());
136            Ok(Some(T::test(val)))
137        }
138        ValueRef::Null => {
139            trace!("{}: ignoring NULL", T::test_name());
140            Ok(None)
141        }
142        #[allow(unused_variables)]
143        ValueRef::Text(val) => {
144            trace!("{}: unsupported Text {val:?}", T::test_name());
145            Err(InvalidFunctionParameterType(0, Type::Text))
146        }
147        #[allow(unused_variables)]
148        ValueRef::Integer(val) => {
149            trace!("{}: unsupported Integer {val:?}", T::test_name());
150            Err(InvalidFunctionParameterType(0, Type::Integer))
151        }
152        #[allow(unused_variables)]
153        ValueRef::Real(val) => {
154            trace!("{}: unsupported Real {val:?}", T::test_name());
155            Err(InvalidFunctionParameterType(0, Type::Real))
156        }
157    }
158}