// sqlite_compressions/common_diff.rs

1use std::panic::{RefUnwindSafe, UnwindSafe};
2
3#[cfg(feature = "trace")]
4use log::trace;
5use rusqlite::functions::{Context, FunctionFlags};
6use rusqlite::types::{Type, ValueRef};
7use rusqlite::Connection;
8use rusqlite::Error::InvalidFunctionParameterType;
9
/// Stand-in for `log::trace!` used when the `trace` feature is disabled.
///
/// It accepts any token sequence and expands to nothing, so its arguments are neither
/// evaluated nor formatted.
#[cfg(not(feature = "trace"))]
macro_rules! trace {
    ($($ignored:tt)*) => {};
}
14
15use crate::rusqlite::Result;
16
17pub trait Differ {
18    fn diff_name() -> &'static str;
19    fn patch_name() -> &'static str;
20    fn diff(source: &[u8], target: &[u8]) -> Result<Vec<u8>>;
21    fn patch(source: &[u8], patch: &[u8]) -> Result<Vec<u8>>;
22}
23
24pub(crate) fn register_differ<T: Differ + UnwindSafe + RefUnwindSafe + 'static>(
25    conn: &Connection,
26) -> Result<()> {
27    // FunctionFlags derive Copy trait only in v0.31+, but we support v0.30+
28    macro_rules! flags {
29        () => {
30            FunctionFlags::SQLITE_UTF8
31                | FunctionFlags::SQLITE_DETERMINISTIC
32                | FunctionFlags::SQLITE_DIRECTONLY
33        };
34    }
35
36    trace!("Registering function {}", T::diff_name());
37    conn.create_scalar_function(T::diff_name(), 2, flags!(), diff_fn::<T>)?;
38
39    trace!("Registering function {}", T::patch_name());
40    conn.create_scalar_function(T::patch_name(), 2, flags!(), patch_fn::<T>)
41}
42
43fn diff_fn<T: Differ + UnwindSafe + RefUnwindSafe + 'static>(
44    ctx: &Context,
45) -> Result<Option<Vec<u8>>> {
46    let Some(source) = get_bytes(ctx, 0)? else {
47        return Ok(None);
48    };
49    let Some(target) = get_bytes(ctx, 1)? else {
50        return Ok(None);
51    };
52    Ok(Some(T::diff(source, target)?))
53}
54
55fn patch_fn<T: Differ + UnwindSafe + RefUnwindSafe + 'static>(
56    ctx: &Context,
57) -> Result<Option<Vec<u8>>> {
58    let Some(source) = get_bytes(ctx, 0)? else {
59        return Ok(None);
60    };
61    let Some(patch) = get_bytes(ctx, 1)? else {
62        return Ok(None);
63    };
64    Ok(Some(T::patch(source, patch)?))
65}
66
67pub(crate) fn get_bytes<'a>(ctx: &'a Context, index: usize) -> Result<Option<&'a [u8]>> {
68    match ctx.get_raw(index) {
69        ValueRef::Blob(val) | ValueRef::Text(val) => Ok(Some(val)),
70        ValueRef::Null => Ok(None),
71        ValueRef::Integer(_) => Err(InvalidFunctionParameterType(index, Type::Integer)),
72        ValueRef::Real(_) => Err(InvalidFunctionParameterType(index, Type::Real)),
73    }
74}