sqlite_compressions/
common_diff.rs1use std::panic::{RefUnwindSafe, UnwindSafe};
2
3#[cfg(feature = "trace")]
4use log::trace;
5use rusqlite::functions::{Context, FunctionFlags};
6use rusqlite::types::{Type, ValueRef};
7use rusqlite::Connection;
8use rusqlite::Error::InvalidFunctionParameterType;
9
/// Stand-in for `log::trace!` used when the `trace` feature is disabled.
///
/// Every token passed to it is discarded. The format arguments are therefore
/// never evaluated, and the call site compiles to nothing.
#[cfg(not(feature = "trace"))]
macro_rules! trace {
    ($($ignored:tt)*) => ();
}
14
15use crate::rusqlite::Result;
16
17pub trait Differ {
18 fn diff_name() -> &'static str;
19 fn patch_name() -> &'static str;
20 fn diff(source: &[u8], target: &[u8]) -> Result<Vec<u8>>;
21 fn patch(source: &[u8], patch: &[u8]) -> Result<Vec<u8>>;
22}
23
24pub(crate) fn register_differ<T: Differ + UnwindSafe + RefUnwindSafe + 'static>(
25 conn: &Connection,
26) -> Result<()> {
27 macro_rules! flags {
29 () => {
30 FunctionFlags::SQLITE_UTF8
31 | FunctionFlags::SQLITE_DETERMINISTIC
32 | FunctionFlags::SQLITE_DIRECTONLY
33 };
34 }
35
36 trace!("Registering function {}", T::diff_name());
37 conn.create_scalar_function(T::diff_name(), 2, flags!(), diff_fn::<T>)?;
38
39 trace!("Registering function {}", T::patch_name());
40 conn.create_scalar_function(T::patch_name(), 2, flags!(), patch_fn::<T>)
41}
42
43fn diff_fn<T: Differ + UnwindSafe + RefUnwindSafe + 'static>(
44 ctx: &Context,
45) -> Result<Option<Vec<u8>>> {
46 let Some(source) = get_bytes(ctx, 0)? else {
47 return Ok(None);
48 };
49 let Some(target) = get_bytes(ctx, 1)? else {
50 return Ok(None);
51 };
52 Ok(Some(T::diff(source, target)?))
53}
54
55fn patch_fn<T: Differ + UnwindSafe + RefUnwindSafe + 'static>(
56 ctx: &Context,
57) -> Result<Option<Vec<u8>>> {
58 let Some(source) = get_bytes(ctx, 0)? else {
59 return Ok(None);
60 };
61 let Some(patch) = get_bytes(ctx, 1)? else {
62 return Ok(None);
63 };
64 Ok(Some(T::patch(source, patch)?))
65}
66
67pub(crate) fn get_bytes<'a>(ctx: &'a Context, index: usize) -> Result<Option<&'a [u8]>> {
68 match ctx.get_raw(index) {
69 ValueRef::Blob(val) | ValueRef::Text(val) => Ok(Some(val)),
70 ValueRef::Null => Ok(None),
71 ValueRef::Integer(_) => Err(InvalidFunctionParameterType(index, Type::Integer)),
72 ValueRef::Real(_) => Err(InvalidFunctionParameterType(index, Type::Real)),
73 }
74}