1use digest::{ExtendableOutput, Update, XofReader};
2use num_bigint::BigUint;
3use std::vec;
4use std::vec::Vec;
5use subtle::Choice;
6use subtle::ConditionallySelectable;
7
8pub struct MWFDH<H, C>
9where
10 H: ExtendableOutput + Update + Default + Clone,
11 C: Fn(&[u8]) -> Choice,
12{
13 iterations: usize,
14 output_size: usize,
15 domain_function: C,
16 inner_hash: H,
17}
18
19impl<H, C> MWFDH<H, C>
20where
21 H: ExtendableOutput + Update + Default + Clone,
22 C: Fn(&[u8]) -> Choice,
23{
24 pub fn new(iterations: usize, output_size: usize, domain_function: C) -> Self {
25 MWFDH {
26 iterations,
27 output_size,
28 domain_function,
29 inner_hash: H::default(),
30 }
31 }
32
33 pub fn input(&mut self, input: &[u8]) {
34 self.inner_hash.update(input);
35 }
36
37 pub fn results_in_domain(&self) -> Result<Vec<u8>, ()> {
38 let mut all_candidates = self.all_candidates();
39
40 let mut selection: u32 = 0;
41 let mut in_domain: Choice = 0.into();
42 for candidate in all_candidates.iter() {
43 in_domain |= (self.domain_function)(candidate);
44 let selection_plus_one = selection + 1;
45 selection.conditional_assign(&selection_plus_one, !in_domain);
46 }
47
48 let found_domain: bool = in_domain.into();
49 if !found_domain {
50 return Err(());
51 }
52
53 let result: Vec<u8> = all_candidates.remove(selection as usize);
55 Ok(result)
56 }
57
58 fn all_candidates(&self) -> Vec<Vec<u8>> {
59 let inner_hash = self.inner_hash.clone();
60 let mut reader = inner_hash.finalize_xof();
61 let underlying_size = self.output_size * (self.iterations);
62 let mut result: Vec<u8> = vec![0x00; underlying_size];
63 reader.read(&mut result);
64
65 compute_candidates(result, self.output_size, self.iterations)
66 }
67}
68
/// Cuts `input` into `num_iterations` overlapping windows of
/// `moving_window_size` bytes.
///
/// Window `k` is `input[k..k + moving_window_size]`, so consecutive windows are
/// shifted by a single byte.
///
/// # Panics
///
/// Panics if `num_iterations > 0` and `input` is shorter than
/// `moving_window_size + num_iterations - 1` bytes.
fn compute_candidates(
    input: Vec<u8>,
    moving_window_size: usize,
    num_iterations: usize,
) -> Vec<Vec<u8>> {
    (0..num_iterations)
        .map(|start| input[start..start + moving_window_size].to_vec())
        .collect()
}
85
86pub fn between(check: &[u8], min: &BigUint, max: &BigUint) -> Choice {
88 let check = BigUint::from_bytes_be(check);
89 Choice::from((&check < max) as u8) & Choice::from((&check > min) as u8)
90}
91
92pub fn lt(check: &[u8], max: &BigUint) -> Choice {
94 let check = BigUint::from_bytes_be(check);
95 ((&check < max) as u8).into()
96}
97
98pub fn gt(check: &[u8], min: &BigUint) -> Choice {
100 let check = BigUint::from_bytes_be(check);
101 ((&check > min) as u8).into()
102}
103
#[cfg(test)]
mod tests {
    use super::*;
    use num_bigint::BigUint;
    use num_traits::Num;
    use sha3::Shake128;
    use subtle::Choice;

    /// Message hashed by every end-to-end test below.
    const MESSAGE: &[u8] = b"ATTACK AT DAWN";

    /// Parses a base-10 literal (test-only; panics on malformed input).
    fn decimal(digits: &str) -> BigUint {
        BigUint::from_str_radix(digits, 10).unwrap()
    }

    /// Hashes `MESSAGE` with SHAKE128 over 2048 windows of 32 bytes, using the
    /// supplied domain predicate, and unwraps the in-domain result.
    fn hash_into_domain<F>(domain_function: F) -> Vec<u8>
    where
        F: Fn(&[u8]) -> Choice,
    {
        let mut hasher = MWFDH::<Shake128, _>::new(2048, 32, domain_function);
        hasher.input(MESSAGE);
        hasher.results_in_domain().unwrap()
    }

    /// Checks both the hex encoding and the base-10 value of `result`.
    fn assert_result(result: &[u8], expected_hex: &str, expected_decimal: &str) {
        assert_eq!(hex::encode(result), expected_hex);
        assert_eq!(
            BigUint::from_bytes_be(result).to_str_radix(10),
            expected_decimal
        );
    }

    #[test]
    fn all_candidates_test() {
        let candidates = compute_candidates(vec![0, 1, 2, 3, 4, 5, 6, 7], 5, 4);

        let expected: [&[u8]; 4] = [
            &[0, 1, 2, 3, 4],
            &[1, 2, 3, 4, 5],
            &[2, 3, 4, 5, 6],
            &[3, 4, 5, 6, 7],
        ];
        for (index, window) in expected.iter().enumerate() {
            assert_eq!(candidates[index], *window);
        }
    }

    #[test]
    fn test_results_between() {
        let min = decimal(
            "51683095453715361952842063988888814558178328011011413557662527675023521115731",
        );
        let max = decimal(
            "63372381656167118369940880608146415619543459354936568979731399163319071519847",
        );

        let result = hash_into_domain(|check: &[u8]| between(check, &min, &max));

        assert_result(
            &result,
            "7ebe111e3d443145d87f7b574f67f92be291f19d747a489601e40bd6f3671008",
            "57327238008737855959412403183414616474281863704162301159073898079241428733960",
        );
    }

    #[test]
    fn test_results_lt() {
        let max = decimal(
            "23372381656167118369940880608146415619543459354936568979731399163319071519847",
        );

        let result = hash_into_domain(|check: &[u8]| lt(check, &max));

        assert_result(
            &result,
            "111e3d443145d87f7b574f67f92be291f19d747a489601e40bd6f36710080831",
            "7742746682851442867075436372447051338297254606827936826213800416869211441201",
        );
    }

    #[test]
    fn test_results_gt() {
        let min = decimal(
            "81683095453715361952842063988888814558178328011011413557662527675023521115731",
        );

        let result = hash_into_domain(|check: &[u8]| gt(check, &min));

        assert_result(
            &result,
            "be111e3d443145d87f7b574f67f92be291f19d747a489601e40bd6f367100808",
            "85969686335050502239631103859465427904139040394838027751262323288751421261832",
        );
    }
}