1use std::panic::{RefUnwindSafe, UnwindSafe};
2
3use digest::Digest;
4#[cfg(feature = "trace")]
5use log::trace;
6use rusqlite::functions::Context;
7use rusqlite::ToSql;
8
9#[cfg(feature = "aggregate")]
10use crate::aggregate::create_agg_function;
11use crate::rusqlite::functions::FunctionFlags;
12use crate::rusqlite::types::{Type, ValueRef};
13use crate::rusqlite::Error::{InvalidFunctionParameterType, InvalidParameterCount};
14use crate::rusqlite::{Connection, Result};
15use crate::state::HashState;
16
/// Fallback for `log::trace!` when the `trace` feature is disabled.
///
/// It lets call sites log unconditionally. The whole invocation expands to nothing, so its
/// arguments are discarded without ever being evaluated.
#[cfg(not(feature = "trace"))]
macro_rules! trace {
    ($($ignored:tt)*) => {};
}
21
22pub trait NamedDigest: Digest {
23 fn name() -> &'static str;
24}
25
/// Implements [`NamedDigest`] for a list of digest types, each behind a cargo feature.
///
/// Two input forms are accepted:
/// * `Type => "name"`: the impl is gated on the feature whose name equals `"name"`.
/// * `Type => "name" @ "feature"`: the impl is gated on the explicitly given feature.
macro_rules! digest_names {
    // Shorthand form: the SQL-facing name doubles as the feature name.
    ($($digest:ty => $label:literal),* $(,)?) => {
        $(
            #[cfg(feature = $label)]
            impl NamedDigest for $digest {
                fn name() -> &'static str {
                    $label
                }
            }
        )*
    };
    // Explicit form: the feature gate is spelled out separately from the name.
    ($($digest:ty => $label:literal @ $gate:literal),* $(,)?) => {
        $(
            #[cfg(feature = $gate)]
            impl NamedDigest for $digest {
                fn name() -> &'static str {
                    $label
                }
            }
        )*
    };
}
45
46digest_names! {
47 md5::Md5 => "md5",
48 sha1::Sha1 => "sha1",
49 sha2::Sha224 => "sha224",
50 sha2::Sha256 => "sha256",
51 sha2::Sha384 => "sha384",
52 sha2::Sha512 => "sha512",
53}
54
55digest_names! {
56 noncrypto_digests::Fnv => "fnv1a" @ "fnv",
57 noncrypto_digests::Xxh32 => "xxh32" @ "xxhash",
58 noncrypto_digests::Xxh64 => "xxh64" @ "xxhash",
59 noncrypto_digests::Xxh3_64 => "xxh3_64" @ "xxhash",
60 noncrypto_digests::Xxh3_128 => "xxh3_128" @ "xxhash",
61}
62
63pub(crate) fn create_hash_fn<T: NamedDigest + Clone + UnwindSafe + RefUnwindSafe + 'static>(
64 conn: &Connection,
65 fn_name: &str,
66) -> Result<()> {
67 create_scalar_function(conn, fn_name, |c| {
68 hash_fn::<T>(
69 c,
70 #[cfg(feature = "trace")]
71 "",
72 )
73 .map(HashState::finalize)
74 })?;
75
76 #[cfg(feature = "hex")]
77 {
78 let fn_name = format!("{fn_name}_hex");
79 create_scalar_function(conn, &fn_name, |c| {
80 hash_fn::<T>(
81 c,
82 #[cfg(feature = "trace")]
83 "_hex",
84 )
85 .map(HashState::finalize_hex)
86 })?;
87 }
88
89 #[cfg(feature = "aggregate")]
90 {
91 let fn_name = format!("{fn_name}_concat");
92 create_agg_function(
93 conn,
94 &fn_name,
95 crate::aggregate::AggType::<T, Vec<u8>>::new(
96 #[cfg(feature = "trace")]
97 &fn_name,
98 HashState::finalize,
99 ),
100 )?;
101 }
102
103 #[cfg(all(feature = "aggregate", feature = "hex"))]
104 {
105 let fn_name = format!("{fn_name}_concat_hex");
106 create_agg_function(
107 conn,
108 &fn_name,
109 crate::aggregate::AggType::<T, String>::new(
110 #[cfg(feature = "trace")]
111 &fn_name,
112 HashState::finalize_hex,
113 ),
114 )?;
115 }
116
117 Ok(())
118}
119
120pub fn create_scalar_function<F, T>(conn: &Connection, fn_name: &str, function: F) -> Result<()>
121where
122 F: Fn(&Context<'_>) -> Result<T> + Send + UnwindSafe + 'static,
124 T: ToSql,
125{
126 trace!("Registering function {fn_name}");
127 conn.create_scalar_function(
128 fn_name,
129 -1,
130 FunctionFlags::SQLITE_UTF8
131 | FunctionFlags::SQLITE_DETERMINISTIC
132 | FunctionFlags::SQLITE_DIRECTONLY,
133 function,
134 )
135}
136
137fn hash_fn<T: NamedDigest + Clone + UnwindSafe + RefUnwindSafe + 'static>(
138 ctx: &Context,
139 #[cfg(feature = "trace")] suffix: &'static str,
140) -> Result<HashState<T>> {
141 let param_count = ctx.len();
142 if param_count == 0 {
143 return Err(InvalidParameterCount(param_count, 1));
144 }
145 let mut state = HashState::<T>::default();
146 for idx in 0..param_count {
147 let value = ctx.get_raw(idx);
148 match value {
149 ValueRef::Blob(val) => {
150 trace!("{}{suffix}: hashing blob arg{idx}={val:?}", T::name());
151 state.add_value(val);
152 }
153 ValueRef::Text(val) => {
154 trace!("{}{suffix}: hashing text arg{idx}={val:?}", T::name());
155 state.add_value(val);
156 }
157 ValueRef::Null => {
158 trace!("{}{suffix}: ignoring arg{idx}=NULL", T::name());
159 state.add_null();
160 }
161 #[allow(unused_variables)]
162 ValueRef::Integer(val) => {
163 trace!(
164 "{}{suffix}: unsupported Integer arg{idx}={val:?}",
165 T::name()
166 );
167 Err(InvalidFunctionParameterType(0, Type::Integer))?;
168 }
169 #[allow(unused_variables)]
170 ValueRef::Real(val) => {
171 trace!("{}{suffix}: unsupported Real arg{idx}={val:?}", T::name());
172 Err(InvalidFunctionParameterType(0, Type::Real))?;
173 }
174 }
175 }
176
177 Ok(state)
178}