sqlite_hashes/
scalar.rs

1use std::panic::{RefUnwindSafe, UnwindSafe};
2
3use digest::Digest;
4#[cfg(feature = "trace")]
5use log::trace;
6use rusqlite::functions::Context;
7use rusqlite::ToSql;
8
9#[cfg(feature = "aggregate")]
10use crate::aggregate::create_agg_function;
11use crate::rusqlite::functions::FunctionFlags;
12use crate::rusqlite::types::{Type, ValueRef};
13use crate::rusqlite::Error::{InvalidFunctionParameterType, InvalidParameterCount};
14use crate::rusqlite::{Connection, Result};
15use crate::state::HashState;
16
// `log::trace` is only imported when the `trace` feature is on. Otherwise this
// stand-in discards every token it is given without evaluating it. That also means
// format arguments such as `suffix` (a `hash_fn` parameter only under `trace`)
// are never resolved.
#[cfg(not(feature = "trace"))]
macro_rules! trace {
    ($($_ignored:tt)*) => ();
}
21
22pub trait NamedDigest: Digest {
23    fn name() -> &'static str;
24}
25
/// Implements `NamedDigest` for a list of digest types, each gated behind a cargo feature.
///
/// Two input forms are accepted; a single invocation must use one form throughout:
///
/// * `Type => "name"` gates the impl on a feature with the same name as `"name"`;
/// * `Type => "name" @ "feature"` gates the impl on an explicitly given feature.
///
/// The short form is rewritten into the long form. The long form emits one
/// `#[cfg(feature = ...)] impl NamedDigest for Type` per entry. An empty
/// invocation expands to nothing.
macro_rules! digest_names {
    // Short form: reuse the name as the feature name. At least one entry is
    // required. With `*`, an empty input would match this arm, forward an empty
    // list back to itself, and recurse until the compiler's recursion limit.
    // Empty input now falls through to the long form below.
    ($($digest:ty => $name:literal),+ $(,)?) => {
        digest_names!(
            $(
                $digest => $name @ $name,
            )+
        );
    };
    // Long form: emit the feature-gated impls (zero entries emit nothing).
    ($($digest:ty => $name:literal @ $feature:literal),* $(,)?) => {
        $(
            #[cfg(feature = $feature)]
            impl NamedDigest for $digest {
                fn name() -> &'static str {
                    $name
                }
            }
        )*
    };
}
45
46digest_names! {
47    md5::Md5 => "md5",
48    sha1::Sha1 => "sha1",
49    sha2::Sha224 => "sha224",
50    sha2::Sha256 => "sha256",
51    sha2::Sha384 => "sha384",
52    sha2::Sha512 => "sha512",
53    blake3::Hasher => "blake3",
54}
55
56// Explicitly specify the feature flags when the fn name is different
57digest_names! {
58    noncrypto_digests::Fnv => "fnv1a" @ "fnv",
59    noncrypto_digests::Xxh32 => "xxh32" @ "xxhash",
60    noncrypto_digests::Xxh64 => "xxh64" @ "xxhash",
61    noncrypto_digests::Xxh3_64 => "xxh3_64" @ "xxhash",
62    noncrypto_digests::Xxh3_128 => "xxh3_128" @ "xxhash",
63}
64
65pub(crate) fn create_hash_fn<T: NamedDigest + Clone + UnwindSafe + RefUnwindSafe + 'static>(
66    conn: &Connection,
67    fn_name: &str,
68) -> Result<()> {
69    create_scalar_function(conn, fn_name, |c| {
70        hash_fn::<T>(
71            c,
72            #[cfg(feature = "trace")]
73            "",
74        )
75        .map(HashState::finalize)
76    })?;
77
78    #[cfg(feature = "hex")]
79    {
80        let fn_name = format!("{fn_name}_hex");
81        create_scalar_function(conn, &fn_name, |c| {
82            hash_fn::<T>(
83                c,
84                #[cfg(feature = "trace")]
85                "_hex",
86            )
87            .map(HashState::finalize_hex)
88        })?;
89    }
90
91    #[cfg(feature = "aggregate")]
92    {
93        let fn_name = format!("{fn_name}_concat");
94        create_agg_function(
95            conn,
96            &fn_name,
97            crate::aggregate::AggType::<T, Vec<u8>>::new(
98                #[cfg(feature = "trace")]
99                &fn_name,
100                HashState::finalize,
101            ),
102        )?;
103    }
104
105    #[cfg(all(feature = "aggregate", feature = "hex"))]
106    {
107        let fn_name = format!("{fn_name}_concat_hex");
108        create_agg_function(
109            conn,
110            &fn_name,
111            crate::aggregate::AggType::<T, String>::new(
112                #[cfg(feature = "trace")]
113                &fn_name,
114                HashState::finalize_hex,
115            ),
116        )?;
117    }
118
119    Ok(())
120}
121
122pub fn create_scalar_function<F, T>(conn: &Connection, fn_name: &str, function: F) -> Result<()>
123where
124    // TODO: Newer versions do not require UnwindSafe
125    F: Fn(&Context<'_>) -> Result<T> + Send + UnwindSafe + 'static,
126    T: ToSql,
127{
128    trace!("Registering function {fn_name}");
129    conn.create_scalar_function(
130        fn_name,
131        -1,
132        FunctionFlags::SQLITE_UTF8
133            | FunctionFlags::SQLITE_DETERMINISTIC
134            | FunctionFlags::SQLITE_DIRECTONLY,
135        function,
136    )
137}
138
139fn hash_fn<T: NamedDigest + Clone + UnwindSafe + RefUnwindSafe + 'static>(
140    ctx: &Context,
141    #[cfg(feature = "trace")] suffix: &'static str,
142) -> Result<HashState<T>> {
143    let param_count = ctx.len();
144    if param_count == 0 {
145        return Err(InvalidParameterCount(param_count, 1));
146    }
147    let mut state = HashState::<T>::default();
148    for idx in 0..param_count {
149        let value = ctx.get_raw(idx);
150        match value {
151            ValueRef::Blob(val) => {
152                trace!("{}{suffix}: hashing blob arg{idx}={val:?}", T::name());
153                state.add_value(val);
154            }
155            ValueRef::Text(val) => {
156                trace!("{}{suffix}: hashing text arg{idx}={val:?}", T::name());
157                state.add_value(val);
158            }
159            ValueRef::Null => {
160                trace!("{}{suffix}: ignoring arg{idx}=NULL", T::name());
161                state.add_null();
162            }
163            #[allow(unused_variables)]
164            ValueRef::Integer(val) => {
165                trace!(
166                    "{}{suffix}: unsupported Integer arg{idx}={val:?}",
167                    T::name()
168                );
169                Err(InvalidFunctionParameterType(0, Type::Integer))?;
170            }
171            #[allow(unused_variables)]
172            ValueRef::Real(val) => {
173                trace!("{}{suffix}: unsupported Real arg{idx}={val:?}", T::name());
174                Err(InvalidFunctionParameterType(0, Type::Real))?;
175            }
176        }
177    }
178
179    Ok(state)
180}