sqlite_hashes/
scalar.rs

1use std::panic::{RefUnwindSafe, UnwindSafe};
2
3use digest::Digest;
4#[cfg(feature = "trace")]
5use log::trace;
6use rusqlite::functions::Context;
7use rusqlite::ToSql;
8
9#[cfg(feature = "aggregate")]
10use crate::aggregate::create_agg_function;
11use crate::rusqlite::functions::FunctionFlags;
12use crate::rusqlite::types::{Type, ValueRef};
13use crate::rusqlite::Error::{InvalidFunctionParameterType, InvalidParameterCount};
14use crate::rusqlite::{Connection, Result};
15use crate::state::HashState;
16
// Stand-in for `log::trace!` when the `trace` feature is disabled: it swallows
// whatever tokens it is given and expands to nothing, so call sites compile
// unchanged in both configurations.
#[cfg(not(feature = "trace"))]
macro_rules! trace {
    ($($discarded:tt)*) => {};
}
21
/// A [`Digest`] that can also report a static, human-readable name.
///
/// Within this file the name is only read by the trace messages emitted while hashing.
pub trait NamedDigest: Digest {
    /// Returns the name of the digest algorithm (for example `"md5"`).
    fn name() -> &'static str;
}
25
// Implements `NamedDigest` for a comma-separated list of digest types.
//
// Two entry forms are accepted (all entries of one invocation must use the same form):
//   * `Type => "name"`             - the impl is gated on a Cargo feature called `name`;
//   * `Type => "name" @ "feature"` - the impl is gated on the Cargo feature `feature`.
macro_rules! digest_names {
    // Short form: forward to the full form, with the name doubling as the feature.
    ($($digest:ty => $label:literal),* $(,)?) => {
        digest_names!($($digest => $label @ $label),*);
    };
    // Full form: emit one feature-gated impl per entry.
    ($($digest:ty => $label:literal @ $gate:literal),* $(,)?) => {
        $(
            #[cfg(feature = $gate)]
            impl NamedDigest for $digest {
                fn name() -> &'static str {
                    $label
                }
            }
        )*
    };
}
45
// Digests whose Cargo feature has the same name as the digest itself
// (short `Type => "name"` form of the macro).
digest_names! {
    md5::Md5 => "md5",
    sha1::Sha1 => "sha1",
    sha2::Sha224 => "sha224",
    sha2::Sha256 => "sha256",
    sha2::Sha384 => "sha384",
    sha2::Sha512 => "sha512",
}
54
// Digests from `noncrypto_digests`; the Cargo feature that gates each impl is
// given explicitly after `@` because it differs from the digest name.
digest_names! {
    noncrypto_digests::Fnv => "fnv1a" @ "fnv",
    noncrypto_digests::Xxh32 => "xxh32" @ "xxhash",
    noncrypto_digests::Xxh64 => "xxh64" @ "xxhash",
    noncrypto_digests::Xxh3_64 => "xxh3_64" @ "xxhash",
    noncrypto_digests::Xxh3_128 => "xxh3_128" @ "xxhash",
}
62
63pub(crate) fn create_hash_fn<T: NamedDigest + Clone + UnwindSafe + RefUnwindSafe + 'static>(
64    conn: &Connection,
65    fn_name: &str,
66) -> Result<()> {
67    create_scalar_function(conn, fn_name, |c| {
68        hash_fn::<T>(
69            c,
70            #[cfg(feature = "trace")]
71            "",
72        )
73        .map(HashState::finalize)
74    })?;
75
76    #[cfg(feature = "hex")]
77    {
78        let fn_name = format!("{fn_name}_hex");
79        create_scalar_function(conn, &fn_name, |c| {
80            hash_fn::<T>(
81                c,
82                #[cfg(feature = "trace")]
83                "_hex",
84            )
85            .map(HashState::finalize_hex)
86        })?;
87    }
88
89    #[cfg(feature = "aggregate")]
90    {
91        let fn_name = format!("{fn_name}_concat");
92        create_agg_function(
93            conn,
94            &fn_name,
95            crate::aggregate::AggType::<T, Vec<u8>>::new(
96                #[cfg(feature = "trace")]
97                &fn_name,
98                HashState::finalize,
99            ),
100        )?;
101    }
102
103    #[cfg(all(feature = "aggregate", feature = "hex"))]
104    {
105        let fn_name = format!("{fn_name}_concat_hex");
106        create_agg_function(
107            conn,
108            &fn_name,
109            crate::aggregate::AggType::<T, String>::new(
110                #[cfg(feature = "trace")]
111                &fn_name,
112                HashState::finalize_hex,
113            ),
114        )?;
115    }
116
117    Ok(())
118}
119
120pub fn create_scalar_function<F, T>(conn: &Connection, fn_name: &str, function: F) -> Result<()>
121where
122    // TODO: Newer versions do not require UnwindSafe
123    F: Fn(&Context<'_>) -> Result<T> + Send + UnwindSafe + 'static,
124    T: ToSql,
125{
126    trace!("Registering function {fn_name}");
127    conn.create_scalar_function(
128        fn_name,
129        -1,
130        FunctionFlags::SQLITE_UTF8
131            | FunctionFlags::SQLITE_DETERMINISTIC
132            | FunctionFlags::SQLITE_DIRECTONLY,
133        function,
134    )
135}
136
137fn hash_fn<T: NamedDigest + Clone + UnwindSafe + RefUnwindSafe + 'static>(
138    ctx: &Context,
139    #[cfg(feature = "trace")] suffix: &'static str,
140) -> Result<HashState<T>> {
141    let param_count = ctx.len();
142    if param_count == 0 {
143        return Err(InvalidParameterCount(param_count, 1));
144    }
145    let mut state = HashState::<T>::default();
146    for idx in 0..param_count {
147        let value = ctx.get_raw(idx);
148        match value {
149            ValueRef::Blob(val) => {
150                trace!("{}{suffix}: hashing blob arg{idx}={val:?}", T::name());
151                state.add_value(val);
152            }
153            ValueRef::Text(val) => {
154                trace!("{}{suffix}: hashing text arg{idx}={val:?}", T::name());
155                state.add_value(val);
156            }
157            ValueRef::Null => {
158                trace!("{}{suffix}: ignoring arg{idx}=NULL", T::name());
159                state.add_null();
160            }
161            #[allow(unused_variables)]
162            ValueRef::Integer(val) => {
163                trace!(
164                    "{}{suffix}: unsupported Integer arg{idx}={val:?}",
165                    T::name()
166                );
167                Err(InvalidFunctionParameterType(0, Type::Integer))?;
168            }
169            #[allow(unused_variables)]
170            ValueRef::Real(val) => {
171                trace!("{}{suffix}: unsupported Real arg{idx}={val:?}", T::name());
172                Err(InvalidFunctionParameterType(0, Type::Real))?;
173            }
174        }
175    }
176
177    Ok(state)
178}