yume_pdq/
lib.rs

1#![doc = include_str!("../README.md")]
2/*
3 * Copyright (c) 2025 Yumechi <yume@yumechi.jp>
4 *
5 * Created on Saturday, March 22, 2025
6 * Author: Yumechi <yume@yumechi.jp>
7 *
8 * SPDX-License-Identifier: Apache-2.0
9 *
10 * Licensed under the Apache License, Version 2.0 (the "License");
11 * you may not use this file except in compliance with the License.
12 * You may obtain a copy of the License at
13 *
14 * http://www.apache.org/licenses/LICENSE-2.0
15 *
16 * Unless required by applicable law or agreed to in writing, software
17 * distributed under the License is distributed on an "AS IS" BASIS,
18 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19 * See the License for the specific language governing permissions and
20 * limitations under the License.
21 */
22#![cfg_attr(all(not(test), not(feature = "std")), no_std)]
23#![cfg_attr(feature = "portable-simd", feature(portable_simd))]
24#![warn(missing_docs, clippy::pedantic)]
25#![allow(
26    clippy::type_complexity,
27    clippy::missing_errors_doc,
28    clippy::doc_markdown,
29    clippy::similar_names,
30    clippy::cast_lossless
31)]
32#![allow(clippy::inline_always)]
33#![allow(
34    clippy::bool_to_int_with_if,
35    reason = "I don't know, I think it's more readable"
36)]
37
38#[cfg(feature = "alloc")]
39extern crate alloc;
40
41pub use const_default::{self, ConstDefault};
42pub use generic_array::{self, GenericArray};
43pub use num_traits;
44pub use zeroize;
45
46use kernel::{
47    Kernel,
48    threshold::threshold_2d_f32,
49    type_traits::{DivisibleBy8, EvaluateHardwareFeature, SquareOf},
50};
51
52use generic_array::{
53    ArrayLength,
54    typenum::{B1, IsLessOrEqual, U16, U32},
55};
56
57/// PDQ compression kernel
58pub mod kernel;
59
60/// PDQ matching solution
61///
62/// Currently all solutions are exact linear-scan nearest neighbor thresholding and are expected to continue to be so
63///
64/// Metric-tree based solutions such as BK-tree and KD-tree are not efficient due to unique characteristics of PDQ hash and dihedral invariance necessitating all screens to match 8 hashes at once. See [TECHNICAL.md](TECHNICAL.md) for more details.
65///
/// Approximate nearest-neighbor (ANN) search will lead to significant, guaranteed false negatives. (By contrast, the 2 outliers in my DISC21 benchmark are still well within the threshold, so they do not imply a guaranteed recall below 98%.)
/// An experiment using Facebook(R) Faiss IndexBinaryHNSW on real NEMEC PDQ data showed 90% recall at nearly 10 ms per query, single-threaded.
68/// Even if one can accept this recall (one shouldn't), performance is still not competitive with any optimized matcher here.
69pub mod matching;
70
71pub use kernel::smart_kernel;
72
73/// Memory alignment utilities.
74pub mod alignment;
75
/// Diagnostic utilities for debugging, for developers integrating this crate, or simply for inspecting internals for fun. Not part of the stable API. Correctness is only checked empirically.
77#[cfg(any(test, all(feature = "unstable", feature = "std")))]
78pub mod testing;
79
80#[cfg(target_arch = "wasm32")]
81/// WASM bindings.
82/// cbindgen:ignore
83pub mod wasm;
84
/// A dihedral transformation (one of the 8 symmetries of a square) packed into a single word.
///
/// `#[repr(C)]` so the same layout can be shared with foreign callers.
#[repr(C)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Dihedrals {
    /// The 2x2 transformation matrix, one signed byte per entry.
    ///
    /// Ordering is first x-to-x, then x-to-y, then y-to-x, then y-to-y. Big-endian signed 8-bit integers packed into a u32.
    pub packed: u32,
}

impl Dihedrals {
    /// Pack a matrix given as `(x-to-x, x-to-y)` and `(y-to-x, y-to-y)` rows.
    ///
    /// Each `i8` is stored as its two's-complement byte, most significant byte first.
    #[allow(clippy::cast_sign_loss)]
    #[must_use]
    pub const fn from_tuples(dx: (i8, i8), dy: (i8, i8)) -> Self {
        let (xx, xy) = dx;
        let (yx, yy) = dy;
        let packed = u32::from_be_bytes([xx as u8, xy as u8, yx as u8, yy as u8]);
        Self { packed }
    }

    /// Unpack into the `((x-to-x, x-to-y), (y-to-x, y-to-y))` form accepted by [`Self::from_tuples`].
    #[allow(clippy::cast_possible_wrap)]
    #[must_use]
    pub const fn into_tuples(self) -> ((i8, i8), (i8, i8)) {
        let [xx, xy, yx, yy] = self.packed.to_be_bytes();
        ((xx as i8, xy as i8), (yx as i8, yy as i8))
    }

    /// The identity transformation.
    pub const NORMAL: Self = Self::from_tuples((1, 0), (0, 1));
    /// The flipped dihedral transformation (y negated).
    pub const FLIPPED: Self = Self::from_tuples((1, 0), (0, -1));
    /// The flopped dihedral transformation (x negated).
    pub const FLOPPED: Self = Self::from_tuples((-1, 0), (0, 1));
    /// The 180-degree rotated dihedral transformation (both axes negated).
    pub const ROTATED_180: Self = Self::from_tuples((-1, 0), (0, -1));
    /// The 90-degree rotated dihedral transformation.
    pub const ROTATED_90: Self = Self::from_tuples((0, 1), (-1, 0));
    /// The 270-degree rotated dihedral transformation.
    pub const ROTATED_270: Self = Self::from_tuples((0, 1), (1, 0));
    /// The 90-degree rotated, then flopped dihedral transformation.
    pub const ROTATED_90_FLOPPED: Self = Self::from_tuples((0, -1), (-1, 0));
    /// The flopped, then 270-degree rotated dihedral transformation.
    pub const FLOPPED_ROTATED_270: Self = Self::from_tuples((0, -1), (1, 0));
}
135
#[cfg(feature = "ffi")]
pub mod ffi {
    //! Foreign function interface binding for the PDQ hash function.
    //!
    //! There is no guarantee of Rust-level API compatibility in this module.
    use generic_array::{sequence::Unflatten, typenum::U32};

    #[allow(clippy::wildcard_imports)]
    use super::*;
    use crate::kernel::{SmartKernelConcreteType, SquareGenericArrayExt, smart_kernel_impl};
    use core::ffi::c_void;
    use std::sync::LazyLock;

    include!(concat!(env!("OUT_DIR"), "/version_ffi.rs"));

    /// Kernel selected once on first use; every FFI call works on its own clone of it.
    static SMART_KERNEL: LazyLock<SmartKernelConcreteType> = LazyLock::new(smart_kernel_impl);

    /// [`Dihedrals::NORMAL`], exported under an unmangled symbol for foreign callers.
    #[unsafe(no_mangle)]
    pub static YUME_PDQ_DIHEDRAL_NORMAL: Dihedrals = Dihedrals::NORMAL;
    /// [`Dihedrals::FLIPPED`], exported under an unmangled symbol for foreign callers.
    #[unsafe(no_mangle)]
    pub static YUME_PDQ_DIHEDRAL_FLIPPED: Dihedrals = Dihedrals::FLIPPED;
    /// [`Dihedrals::FLOPPED`], exported under an unmangled symbol for foreign callers.
    #[unsafe(no_mangle)]
    pub static YUME_PDQ_DIHEDRAL_FLOPPED: Dihedrals = Dihedrals::FLOPPED;
    /// [`Dihedrals::ROTATED_180`], exported under an unmangled symbol for foreign callers.
    #[unsafe(no_mangle)]
    pub static YUME_PDQ_DIHEDRAL_ROTATED_180: Dihedrals = Dihedrals::ROTATED_180;
    /// [`Dihedrals::ROTATED_90`], exported under an unmangled symbol for foreign callers.
    #[unsafe(no_mangle)]
    pub static YUME_PDQ_DIHEDRAL_ROTATED_90: Dihedrals = Dihedrals::ROTATED_90;
    /// [`Dihedrals::ROTATED_270`], exported under an unmangled symbol for foreign callers.
    #[unsafe(no_mangle)]
    pub static YUME_PDQ_DIHEDRAL_ROTATED_270: Dihedrals = Dihedrals::ROTATED_270;
    /// [`Dihedrals::ROTATED_90_FLOPPED`], exported under an unmangled symbol for foreign callers.
    #[unsafe(no_mangle)]
    pub static YUME_PDQ_DIHEDRAL_ROTATED_90_FLOPPED: Dihedrals = Dihedrals::ROTATED_90_FLOPPED;
    /// [`Dihedrals::FLOPPED_ROTATED_270`], exported under an unmangled symbol for foreign callers.
    #[unsafe(no_mangle)]
    pub static YUME_PDQ_DIHEDRAL_FLOPPED_ROTATED_270: Dihedrals = Dihedrals::FLOPPED_ROTATED_270;

    /// A callback function for visiting all dihedrals.
    ///
    /// Arguments:
    /// - `ctx`: the opaque pointer passed to [`visit_dihedrals_smart_kernel`].
    /// - `dihedral`: the packed [`Dihedrals`] matrix of the variant being visited.
    /// - `threshold`: the quantization threshold of *this* dihedral variant.
    /// - `quality`: the quality score recomputed for this variant.
    ///
    /// The PDQF and quantized output of the variant are available to the caller via the
    /// provided buffers ONLY before the callback returns.
    ///
    /// Return true to continue, false to stop.
    ///
    /// The function must not modify the buffers, and must copy them out before returning if it needs to keep them.
    pub type DihedralCallback =
        extern "C" fn(ctx: *mut c_void, dihedral: u32, threshold: f32, quality: f32) -> bool;

    #[unsafe(export_name = "yume_pdq_visit_dihedrals_smart_kernel")]
    /// Visit the 7 alternative dihedrals of the PDQF hash.
    ///
    /// # Safety
    ///
    /// - `ctx` is transparently passed to the callback function.
    /// - `threshold` must be a valid threshold value for the provided PDQF input received from [`hash_smart_kernel`].
    /// - `output` is out only, must be a pointer to a 2x16 array of u8 to receive any intermediate 256-bit hash. It does not have to be initialized to any particular value.
    /// - `pdqf` is in/out, must be a pointer to a 16x16 array of f32 values of the initial PDQF data, and be writable to receive derived PDQF (unquantized) hash values.
    /// - `callback` must be a valid callback function that will be called for each dihedral.
    ///
    /// No buffer should overlap.
    ///
    /// On return, `pdqf` and `output` hold the last visited variant, not the original input.
    ///
    /// # Returns
    ///
    /// - `true` if all dihedrals were visited, `false` if the callback returned false for any dihedral.
    pub unsafe extern "C" fn visit_dihedrals_smart_kernel(
        ctx: *mut c_void,
        threshold: f32,
        output: *mut u8,
        pdqf: *mut f32,
        callback: DihedralCallback,
    ) -> bool {
        // SAFETY: the caller guarantees `output` points to 32 writable bytes and `pdqf` to
        // 256 writable, aligned f32 values, and that the two buffers do not overlap.
        let output = unsafe { &mut *output.cast::<[u8; 2 * 16]>() };
        let pdqf = unsafe { &mut *pdqf.cast::<[f32; 16 * 16]>() };

        let pdqf = GenericArray::from_mut_slice(pdqf).unflatten_square_mut();
        let output = GenericArray::<_, U32>::from_mut_slice(output).unflatten();

        #[allow(clippy::clone_on_copy)]
        let mut kernel = SMART_KERNEL.clone();

        crate::visit_dihedrals(
            &mut kernel,
            pdqf,
            output,
            threshold,
            |dihedral, dihedral_threshold, (quality, _pdqf, _output)| {
                // Forward the threshold of the variant being visited: the negated variants
                // are re-quantized and carry their own thresholds, which differ from the
                // caller's input `threshold`.
                if callback(ctx, dihedral.packed, dihedral_threshold, quality) {
                    Ok(())
                } else {
                    Err(())
                }
            },
        )
        .is_ok()
    }

    #[unsafe(export_name = "yume_pdq_hash_smart_kernel")]
    /// Compute the PDQ hash of a 512x512 single-channel image using [`kernel::smart_kernel`].
    ///
    /// # Safety
    ///
    /// - `input` is in only, must be a pointer to a 512x512 single-channel image in float32 format, row-major order.
    /// - `threshold` is out only, must be a valid aligned pointer to a f32 value or NULL.
    /// - `output` is out only, must be a pointer to a 2x16 array of u8 to receive the final 256-bit hash.
    /// - `buf1` is in/out, must be a pointer to a 128x128 array of f32 values to receive the intermediate results of the DCT transform.
    /// - `tmp` is in/out, must be a pointer to a 128x1 array of f32 values as scratch space for the DCT transform.
    /// - `pdqf` is out only, must be a pointer to a 16x16 array of f32 values to receive PDQF (unquantized) hash values.
    ///
    /// No buffer should overlap.
    ///
    /// # Returns
    ///
    /// The quality of the hash as a f32 value between 0.0 and 1.0. You are responsible for checking whether quality is acceptable.
    pub unsafe extern "C" fn hash_smart_kernel(
        input: *const f32,
        threshold: *mut f32,
        output: *mut u8,
        buf1: *mut f32,
        tmp: *mut f32,
        pdqf: *mut f32,
    ) -> f32 {
        // SAFETY: the caller guarantees every pointer references a valid, aligned,
        // non-overlapping buffer of the size documented under `# Safety`.
        let input = unsafe { &*input.cast::<[f32; 512 * 512]>() };
        let output = unsafe { &mut *output.cast::<[u8; 2 * 16]>() };
        let buf1 = unsafe { &mut *buf1.cast::<[f32; 128 * 128]>() };
        let tmp = unsafe { &mut *tmp.cast::<[f32; 128]>() };
        let pdqf = unsafe { &mut *pdqf.cast::<[f32; 16 * 16]>() };

        #[allow(clippy::clone_on_copy)]
        let mut kernel = SMART_KERNEL.clone();
        let input = GenericArray::from_slice(input).unflatten_square_ref();
        let output = GenericArray::<_, U32>::from_mut_slice(output).unflatten();
        let buf1 = GenericArray::from_mut_slice(buf1).unflatten_square_mut();
        let pdqf = GenericArray::from_mut_slice(pdqf).unflatten_square_mut();

        // A NULL `threshold` means the caller does not want it; write into a local instead.
        let mut dummy_threshold = 0.0;
        // SAFETY: `threshold` is either NULL or a valid, aligned, writable f32 pointer.
        let threshold = unsafe { threshold.as_mut() }.unwrap_or(&mut dummy_threshold);

        crate::hash_get_threshold(
            &mut kernel,
            input,
            threshold,
            output,
            buf1,
            tmp.into(),
            pdqf,
        )
    }
}
285
/// PDQ hash type (quantized): `L` rows of packed bits.
///
/// Each row holds `<L as DivisibleBy8>::Output` bytes; for the default `L = U16` that is
/// `U2`, i.e. 16 rows x 2 bytes = 256 bits.
pub type PDQHash<L = U16> = GenericArray<GenericArray<u8, <L as DivisibleBy8>::Output>, L>;

/// Unquantized PDQ hash ("PDQF" in the original paper): an `L` x `L` matrix of `N` values
/// (`f32` and `U16` by default).
pub type PDQHashF<N = f32, L = U16> = GenericArray<GenericArray<N, L>, L>;
291
292/// Compute the PDQ hash of a 512x512 single-channel image using the given kernel.
293///
294/// This is a convenience wrapper function and just calls [`hash_get_threshold`] with a dummy output location.
295///
296/// **Warning**: While it may be tempting, DO NOT pass uninitialized memory into any parameter of this function.
297/// While the contents are not important, the padding must be zero-initialized otherwise subtly incorrect results will be returned.
298///
299/// # TLDR how to use this contraption
300///
301/// ```rust,no_run
302/// use yume_pdq::{smart_kernel, GenericArray};
303///
304/// // Create a 512x512 input image
305/// //
306/// // values 0.0-255.0 if you want the quality for be accurate, otherwise scale is not important
307/// // this is a known limitation and will be fixed in the future
308/// let input: GenericArray<GenericArray<f32, _>, _> = GenericArray::default();
309///
310/// // Get the optimal kernel for your CPU
311/// let mut kernel = smart_kernel();
312///
313/// // Allocate output and temporary buffers (make sure your stack is big enough or allocate on the heap)
314/// let mut output = GenericArray::default();  // Will contain the final 256-bit hash
315/// let mut buf1 = GenericArray::default();    // Temporary buffer
316/// let mut row_tmp = GenericArray::default();    // Temporary buffer
317/// let mut pdqf = GenericArray::default();    // Temporary buffer (PDQF unquantized hash)
318///
319/// // Compute the hash
320/// let quality = yume_pdq::hash(&mut kernel, &input, &mut output, &mut buf1, &mut row_tmp, &mut pdqf);
321#[inline]
322pub fn hash<K: Kernel>(
323    kernel: &mut K,
324    input: &GenericArray<GenericArray<f32, K::InputDimension>, K::InputDimension>,
325    output: &mut GenericArray<
326        GenericArray<u8, <K::OutputDimension as DivisibleBy8>::Output>,
327        K::OutputDimension,
328    >,
329    buf1: &mut GenericArray<GenericArray<K::InternalFloat, K::Buffer1WidthX>, K::Buffer1LengthY>,
330    tmp: &mut GenericArray<K::InternalFloat, K::Buffer1WidthX>,
331    // the floating point version of the input image
332    pdqf: &mut PDQHashF<K::InternalFloat, K::OutputDimension>,
333) -> f32
334where
335    <K as Kernel>::OutputDimension: DivisibleBy8,
336    <K as Kernel>::InputDimension: SquareOf,
337    <<K as Kernel>::InputDimension as SquareOf>::Output: ArrayLength,
338    <K as Kernel>::OutputDimension: SquareOf,
339    <<K as Kernel>::OutputDimension as SquareOf>::Output: ArrayLength,
340    <K as Kernel>::RequiredHardwareFeature: EvaluateHardwareFeature<EnabledStatic = B1>,
341{
342    hash_get_threshold(
343        kernel,
344        input,
345        &mut Default::default(),
346        output,
347        buf1,
348        tmp,
349        pdqf,
350    )
351}
352
/// Visit the 7 alternative dihedrals of the PDQF hash.
///
/// The identity ([`Dihedrals::NORMAL`]) is not visited. Variants are produced in place in
/// this order: `FLOPPED`, `ROTATED_180`, `FLIPPED`, `ROTATED_90_FLOPPED`, `ROTATED_270`,
/// `FLOPPED_ROTATED_270`, `ROTATED_90`.
///
/// The callback function is called with first the matrix of the dihedral, then the threshold, then the quality, then the PDQF hash and finally the output hash.
/// The threshold is the one used to quantize that particular variant, and the quality is
/// recomputed from the variant's PDQF.
///
/// The PDQF hash and output hash are guaranteed to point to the same buffer as the input hash, it is just to make the borrow-checker happy.
///
/// `threshold` is used as-is for the plain transposed variant, so it must be the threshold of the
/// PDQF passed in (e.g. as obtained from [`hash_get_threshold`]).
///
/// If the callback returns `Err`, visiting stops immediately and the error is returned.
/// Either way, `pdqf` and `output` are left holding the last visited variant rather than the
/// original hash.
pub fn visit_dihedrals<
    K: Kernel<InternalFloat = f32>,
    E,
    F: FnMut(
        Dihedrals,
        f32,
        (
            f32,
            &mut PDQHashF<K::InternalFloat, K::OutputDimension>,
            &mut PDQHash<K::OutputDimension>,
        ),
    ) -> Result<(), E>,
>(
    kernel: &mut K,
    pdqf: &mut PDQHashF<K::InternalFloat, K::OutputDimension>,
    output: &mut PDQHash<K::OutputDimension>,
    threshold: K::InternalFloat,
    mut f: F,
) -> Result<(), E>
where
    K::OutputDimension: DivisibleBy8 + IsLessOrEqual<U32, Output = B1>,
    <K as Kernel>::RequiredHardwareFeature: EvaluateHardwareFeature<EnabledStatic = B1>,
{
    // Recompute quality for the current contents of `pdqf` and hand the variant to the
    // callback; `?` propagates the first callback error out of this function.
    macro_rules! callback {
        ($dihedral:expr, $threshold:expr) => {
            let gradient = kernel.sum_of_gradients(pdqf);
            let quality = K::adjust_quality(gradient);
            f($dihedral, $threshold, (quality, pdqf, output))?;
        };
    }

    // Thresholds of the three negated (non-transposed) variants. Each is filled in by
    // `quantize` below and later reused for the matching transposed variant, where the
    // cheaper `threshold_2d_f32` is applied instead of a full `quantize`.
    let mut threshold_negate_alt_cols = threshold;
    let mut threshold_negate_alt_rows = threshold;
    let mut threshold_negate_off_diagonals = threshold;
    kernel.pdqf_negate_alt_cols::<false>(pdqf); // first negate by columns
    kernel.quantize(pdqf, &mut threshold_negate_alt_cols, output);
    callback!(Dihedrals::FLOPPED, threshold_negate_alt_cols);
    kernel.pdqf_negate_alt_rows::<true>(pdqf); // then negate by rows, getting the negate-by off-diagonals
    kernel.quantize(pdqf, &mut threshold_negate_off_diagonals, output);
    callback!(Dihedrals::ROTATED_180, threshold_negate_off_diagonals);
    kernel.pdqf_negate_alt_cols::<false>(pdqf); // then negate by columns again, getting the negate-by alt-rows
    kernel.quantize(pdqf, &mut threshold_negate_alt_rows, output);
    callback!(Dihedrals::FLIPPED, threshold_negate_alt_rows);
    // undo all negations, transpose
    kernel.pdqf_negate_alt_rows::<true>(pdqf);
    kernel.pdqf_t(pdqf);
    // the un-negated transposed variant reuses the caller-supplied threshold
    threshold_2d_f32(pdqf, output, threshold);
    callback!(Dihedrals::ROTATED_90_FLOPPED, threshold);
    // now undo the original transformations to get back to the other 3 hashes that require transposition
    kernel.pdqf_negate_alt_rows::<true>(pdqf);
    threshold_2d_f32(pdqf, output, threshold_negate_alt_cols);
    callback!(Dihedrals::ROTATED_270, threshold_negate_alt_cols);
    kernel.pdqf_negate_alt_cols::<false>(pdqf);
    threshold_2d_f32(pdqf, output, threshold_negate_off_diagonals);
    callback!(
        Dihedrals::FLOPPED_ROTATED_270,
        threshold_negate_off_diagonals
    );
    kernel.pdqf_negate_alt_rows::<true>(pdqf);
    threshold_2d_f32(pdqf, output, threshold_negate_alt_rows);
    callback!(Dihedrals::ROTATED_90, threshold_negate_alt_rows);
    // NOTE: `pdqf`/`output` are intentionally not restored to the original hash here.
    Ok(())
}
421
422/// Compute the PDQ hash of a 512x512 single-channel image using the given kernel, obtaining the threshold value useful for [`kernel::threshold::threshold_2d_f32`].
423///
424/// **Warning**: While it may be tempting, DO NOT pass uninitialized memory into any parameter of this function.
425/// While the contents are not important, the padding must be zero-initialized otherwise subtly incorrect results will be returned.
426#[inline]
427pub fn hash_get_threshold<K: Kernel>(
428    kernel: &mut K,
429    input: &GenericArray<GenericArray<f32, K::InputDimension>, K::InputDimension>,
430    threshold: &mut K::InternalFloat,
431    output: &mut GenericArray<
432        GenericArray<u8, <K::OutputDimension as DivisibleBy8>::Output>,
433        K::OutputDimension,
434    >,
435    buf1: &mut GenericArray<GenericArray<K::InternalFloat, K::Buffer1WidthX>, K::Buffer1LengthY>,
436    tmp: &mut GenericArray<K::InternalFloat, K::Buffer1WidthX>,
437    // the floating point version of the input image
438    pdqf: &mut PDQHashF<K::InternalFloat, K::OutputDimension>,
439) -> f32
440where
441    <K as Kernel>::InputDimension: SquareOf,
442    <<K as Kernel>::InputDimension as SquareOf>::Output: ArrayLength,
443    <K as Kernel>::OutputDimension: SquareOf,
444    <<K as Kernel>::OutputDimension as SquareOf>::Output: ArrayLength,
445    <K as Kernel>::OutputDimension: DivisibleBy8,
446    <K as Kernel>::RequiredHardwareFeature: EvaluateHardwareFeature<EnabledStatic = B1>,
447{
448    kernel.jarosz_compress(input, buf1);
449    kernel.dct2d(buf1, tmp, pdqf);
450    let gradient = kernel.sum_of_gradients(pdqf);
451    let quality = K::adjust_quality(gradient);
452
453    kernel.quantize(pdqf, threshold, output);
454    quality
455}
456
457/// Compute the PDQ hash of a 512x512 single-channel image using the given kernel without quantization.
458///
459/// This is called PDQF in the original paper.
460///
461/// **Warning**: While it may be tempting, DO NOT pass uninitialized memory into any parameter of this function.
462/// While the contents are not important, the padding must be zero-initialized otherwise subtly incorrect results will be returned.
463///
464pub fn hash_float<K: Kernel>(
465    kernel: &mut K,
466    input: &GenericArray<GenericArray<f32, K::InputDimension>, K::InputDimension>,
467    output: &mut PDQHashF<K::InternalFloat, K::OutputDimension>,
468    buf1: &mut GenericArray<GenericArray<K::InternalFloat, K::Buffer1WidthX>, K::Buffer1LengthY>,
469    tmp: &mut GenericArray<K::InternalFloat, K::Buffer1WidthX>,
470) -> f32
471where
472    <K as Kernel>::RequiredHardwareFeature: EvaluateHardwareFeature<EnabledStatic = B1>,
473{
474    kernel.jarosz_compress(input, buf1);
475    kernel.dct2d(buf1, tmp, output);
476    let gradient = kernel.sum_of_gradients(output);
477
478    K::adjust_quality(gradient)
479}
480
#[cfg(test)]
mod tests {

    use core::ops::Mul;

    use generic_array::{
        sequence::Flatten,
        typenum::{U2, U512},
    };
    use pdqhash::image::{self, DynamicImage};

    use crate::kernel::{
        ReferenceKernel, SquareGenericArrayExt,
        type_traits::{DivisibleBy8, SquareOf},
    };

    use super::*;

    const TEST_IMAGE_AAA_ORIG: &[u8] = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/test-data/aaa-orig.jpg"
    ));
    const TEST_IMAGE_ANIME: &[u8] =
        include_bytes!(concat!(env!("CARGO_MANIFEST_DIR"), "/test-data/anime.png"));
    const TEST_IMAGE_MUSIC: &[u8] =
        include_bytes!(concat!(env!("CARGO_MANIFEST_DIR"), "/test-data/music.png"));
    const TEST_IMAGE_NEOFETCH: &[u8] = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/test-data/neofetch.png"
    ));

    /// All test images as `(name, encoded bytes)`, shared by the comparison tests.
    const TEST_IMAGES: [(&str, &[u8]); 4] = [
        ("aaa-orig.jpg", TEST_IMAGE_AAA_ORIG),
        ("anime.png", TEST_IMAGE_ANIME),
        ("music.png", TEST_IMAGE_MUSIC),
        ("neofetch.png", TEST_IMAGE_NEOFETCH),
    ];

    /// Compare `kernel` against the `pdqhash` crate's full-size PDQ implementation.
    fn test_hash_impl_lib<K: Kernel>(kernel: &mut K)
    where
        K: Kernel<OutputDimension = U16, InputDimension = U512>,
        K::RequiredHardwareFeature: EvaluateHardwareFeature<EnabledStatic = B1>,
    {
        let mut buf1 = Box::default();
        let mut buf2 = Box::default();

        for (name, image_data) in TEST_IMAGES {
            let input = image::load_from_memory(image_data).unwrap();

            let input = input.resize_exact(512, 512, image::imageops::FilterType::Triangle);

            let input_image = input.to_luma8();

            let input_image_f = input_image
                .as_raw()
                .iter()
                .map(|p| *p as f32)
                .collect::<Vec<_>>();

            let output_expected =
                pdqhash::generate_pdq_full_size(&DynamicImage::ImageLuma8(input_image));

            let mut output = GenericArray::default();

            let quality = hash(
                kernel,
                GenericArray::<_, _>::from_slice(input_image_f.as_slice()).unflatten_square_ref(),
                &mut output,
                &mut buf1,
                &mut GenericArray::default(),
                &mut buf2,
            );

            let mut distance = 0;
            for (a, b) in output.flatten().iter().zip(output_expected.0.iter()) {
                let bits_diff = (a ^ b).count_ones();
                distance += bits_diff;
            }

            println!(
                "[{} ({})] {}: Distance vs. library: {}/{} (Qin={}, Qout={})",
                std::any::type_name::<K>(),
                std::any::type_name::<K::InternalFloat>(),
                name,
                distance,
                16 * 16,
                output_expected.1,
                quality
            );

            // this is demo "bad" picture for hashing, highly malleable output is expected
            // when there is any non-prescribed preprocessing happening
            if name != "neofetch.png" {
                // half of the matching threshold
                assert!(
                    distance <= 16,
                    "{name}: distance {distance} vs. library exceeds 16"
                );
            }
        }
    }

    /// Compare `kernel` against [`ReferenceKernel`] using the same internal float type.
    fn test_hash_impl_ref<
        ID: ArrayLength + SquareOf,
        OD: ArrayLength + SquareOf + Mul<OD>,
        K: Kernel<InputDimension = ID, OutputDimension = OD>,
    >(
        kernel: &mut K,
    ) where
        OD: DivisibleBy8,
        <ID as SquareOf>::Output: ArrayLength,
        <OD as SquareOf>::Output: ArrayLength,
        <OD as Mul<OD>>::Output: ArrayLength,
        ReferenceKernel<K::InternalFloat>:
            Kernel<InputDimension = ID, InternalFloat = K::InternalFloat, OutputDimension = OD>,
        <ReferenceKernel<K::InternalFloat> as Kernel>::RequiredHardwareFeature:
            EvaluateHardwareFeature<EnabledStatic = B1>,
        K::RequiredHardwareFeature: EvaluateHardwareFeature<EnabledStatic = B1>,
    {
        let mut buf1 = Box::default();
        let mut buf1a = Box::default();
        let mut buf2 = Box::default();

        for (name, image_data) in TEST_IMAGES {
            let input = image::load_from_memory(image_data).unwrap();

            let input = input.resize_exact(
                ID::USIZE as _,
                ID::USIZE as _,
                image::imageops::FilterType::Triangle,
            );

            let input_image = input
                .to_luma8()
                .iter()
                .map(|p| *p as f32)
                .collect::<Vec<_>>();

            let mut output = GenericArray::default();
            let mut output_ref = GenericArray::default();

            hash(
                kernel,
                GenericArray::from_slice(input_image.as_slice()).unflatten_square_ref(),
                &mut output,
                &mut buf1,
                &mut GenericArray::default(),
                &mut buf2,
            );
            let mut ref_kernel = ReferenceKernel::<K::InternalFloat>::default();
            let mut thres = K::InternalFloat::default();
            let quality_ref = hash_get_threshold(
                &mut ref_kernel,
                GenericArray::from_slice(input_image.as_slice()).unflatten_square_ref(),
                &mut thres,
                &mut output_ref,
                &mut buf1a,
                &mut GenericArray::default(),
                &mut buf2,
            );
            let mut distance = 0;
            for (a, b) in output.iter().flatten().zip(output_ref.iter().flatten()) {
                let bits_diff = (a ^ b).count_ones();
                distance += bits_diff;
            }

            println!(
                "[{} ({})] {}: Distance vs. ref32: {}/{} (Q={})",
                std::any::type_name::<K>(),
                std::any::type_name::<K::InternalFloat>(),
                name,
                distance,
                16 * 16,
                quality_ref
            );
        }
    }

    #[test]
    fn test_hash_rethreshold() {
        let input = image::load_from_memory(TEST_IMAGE_AAA_ORIG).unwrap();
        let input = input.resize_exact(512, 512, image::imageops::FilterType::Triangle);
        let input_image = input
            .to_luma8()
            .iter()
            .map(|p| *p as f32)
            .collect::<Vec<_>>();
        let mut output = GenericArray::default();
        let mut output_rethres = GenericArray::<GenericArray<u8, U2>, U16>::default();
        let mut buf1 = Box::default();
        let mut thres = 0.0f32;
        let mut pdqf = GenericArray::<GenericArray<f32, U16>, U16>::default();

        hash_get_threshold(
            &mut kernel::DefaultKernelNoPadding::default(),
            GenericArray::from_slice(input_image.as_slice()).unflatten_square_ref(),
            &mut thres,
            &mut output,
            &mut buf1,
            &mut GenericArray::default(),
            &mut pdqf,
        );

        kernel::threshold::threshold_2d_f32::<U16>(&pdqf, &mut output_rethres, thres);

        assert_eq!(output_rethres, output);
    }

    /// Compare `kernel` against [`ReferenceKernel`] running on 96-bit arbitrary-precision floats.
    #[cfg(feature = "reference-rug")]
    fn test_hash_impl_ref_arb<
        ID: ArrayLength + SquareOf,
        OD: ArrayLength + SquareOf + DivisibleBy8 + Mul<OD>,
        K: Kernel<InputDimension = ID, OutputDimension = OD>,
    >(
        kernel: &mut K,
    ) where
        K::RequiredHardwareFeature: EvaluateHardwareFeature<EnabledStatic = B1>,
        ReferenceKernel<crate::kernel::float128::ArbFloat<96>>: Kernel<
                InputDimension = ID,
                InternalFloat = crate::kernel::float128::ArbFloat<96>,
                OutputDimension = OD,
            >,
        <OD as Mul<OD>>::Output: ArrayLength,
    {
        use generic_array::typenum::U127;

        use crate::kernel::float128::ArbFloat;
        let mut buf1 = Box::default();
        let mut buf1a_uninit =
            Box::<GenericArray<GenericArray<ArbFloat<96>, U127>, U127>>::new_uninit();
        // Initialize element-by-element through a raw pointer so that no reference to the
        // still-uninitialized buffer is ever created (calling `assume_init_mut` before every
        // element is written is undefined behavior).
        let base = (*buf1a_uninit).as_mut_ptr().cast::<ArbFloat<96>>();
        for idx in 0..127 * 127 {
            // SAFETY: `GenericArray<GenericArray<T, N>, M>` has the layout of `[[T; N]; M]`,
            // i.e. `M * N` contiguous `T`s, so `base.add(idx)` stays in bounds and each slot
            // is written exactly once.
            unsafe { base.add(idx).write(ArbFloat::<96>::default()) };
        }
        // SAFETY: every element was initialized by the loop above.
        let mut buf1a = unsafe { buf1a_uninit.assume_init() };
        let mut buf2 = Box::default();
        let mut buf2a = Box::default();

        for (name, image_data) in TEST_IMAGES {
            let input = image::load_from_memory(image_data).unwrap();

            let input = input.resize_exact(
                ID::USIZE as _,
                ID::USIZE as _,
                image::imageops::FilterType::Triangle,
            );

            let input_image = input
                .to_luma8()
                .iter()
                .map(|p| *p as f32)
                .collect::<Vec<_>>();

            let mut output = GenericArray::default();
            let mut output_ref = GenericArray::default();

            hash(
                kernel,
                GenericArray::from_slice(input_image.as_slice()).unflatten_square_ref(),
                &mut output,
                &mut buf1,
                &mut GenericArray::default(),
                &mut buf2,
            );
            let mut ref_kernel = ReferenceKernel::<ArbFloat<96>>::default();
            let quality_ref = hash(
                &mut ref_kernel,
                GenericArray::from_slice(input_image.as_slice()).unflatten_square_ref(),
                &mut output_ref,
                &mut buf1a,
                &mut GenericArray::default(),
                &mut buf2a,
            );
            let mut distance = 0;
            for (a, b) in output.iter().flatten().zip(output_ref.iter().flatten()) {
                let bits_diff = (a ^ b).count_ones();
                distance += bits_diff;
            }

            println!(
                "[{} ({})] {}: Distance vs. ref96: {}/{} (Q={})",
                std::any::type_name::<K>(),
                std::any::type_name::<K::InternalFloat>(),
                name,
                distance,
                16 * 16,
                quality_ref
            );
        }
    }

    #[test]
    fn test_hash_impl_base() {
        let mut kernel = kernel::DefaultKernelNoPadding::default();
        test_hash_impl_lib(&mut kernel);
        test_hash_impl_ref(&mut kernel);
    }

    #[test]
    #[cfg(feature = "portable-simd")]
    fn test_hash_impl_portable_simd() {
        let mut kernel = kernel::portable_simd::PortableSimdF32Kernel::<8>;
        test_hash_impl_lib(&mut kernel);
        test_hash_impl_ref(&mut kernel);
    }

    #[cfg(feature = "reference-rug")]
    #[test]
    fn test_hash_impl_base_arb() {
        let mut kernel = kernel::DefaultKernelNoPadding::default();
        test_hash_impl_ref_arb(&mut kernel);
    }

    #[cfg(feature = "reference-rug")]
    #[test]
    fn test_hash_impl_ref_arb_rug() {
        let mut kernel = kernel::ReferenceKernel::<f32>::default();
        test_hash_impl_ref_arb(&mut kernel);
    }

    #[test]
    #[cfg(all(
        target_arch = "x86_64",
        all(target_feature = "avx2", target_feature = "fma")
    ))]
    fn test_hash_impl_avx2() {
        let mut kernel = kernel::x86::Avx2F32Kernel;
        test_hash_impl_lib(&mut kernel);
    }

    #[cfg(all(
        target_arch = "x86_64",
        all(target_feature = "avx2", target_feature = "fma"),
        feature = "reference-rug"
    ))]
    #[test]
    fn test_hash_impl_avx2_arb() {
        let mut kernel = kernel::x86::Avx2F32Kernel;
        test_hash_impl_ref_arb(&mut kernel);
    }

    #[test]
    #[cfg(all(target_arch = "x86_64", feature = "avx512", target_feature = "avx512f"))]
    fn test_hash_impl_avx512() {
        let mut kernel = kernel::x86::Avx512F32Kernel;
        test_hash_impl_lib(&mut kernel);
    }

    #[cfg(all(
        target_arch = "x86_64",
        feature = "reference-rug",
        target_feature = "avx512f",
        feature = "avx512"
    ))]
    #[test]
    fn test_hash_impl_avx512_arb() {
        let mut kernel = kernel::x86::Avx512F32Kernel;
        test_hash_impl_ref_arb(&mut kernel);
    }
}