Skip to main content

singe_npp/image/
statistics_batch_metrics.rs

1use singe_cuda::memory::DeviceMemory;
2use singe_npp_sys as sys;
3
4use crate::{
5    context::StreamContext,
6    error::Result,
7    image::view::{C1, C3, C4, ImageView},
8    try_ffi,
9    types::{DataTypeLike, Size},
10    utility::to_usize,
11};
12
13use super::statistics_validation::*;
14
/// Generates two public wrappers around one fixed-ROI NPP batch image-pair metric.
///
/// Positional arguments:
/// 1. name of the generated buffer-size query function,
/// 2. name of the generated metric function,
/// 3. pixel element type of the image views,
/// 4. channel layout marker type (e.g. `C1`, `C3`, `C4`),
/// 5. channel count forwarded to `validate_metric_output`,
/// 6. NPP buffer-size symbol in `sys`,
/// 7. NPP metric symbol in `sys`.
macro_rules! impl_image_pair_metric_batch {
    (
        $buffer_size_name:ident,
        $name:ident,
        $ty:ty,
        $layout:ty,
        $channels:expr,
        $buffer_size_ffi:ident,
        $ffi:ident
    ) => {
        #[doc = concat!(
            "Queries `sys::",
            stringify!($buffer_size_ffi),
            "` for the scratch size, in bytes, associated with `roi`; [`",
            stringify!($name),
            "`] passes this value to its scratch-buffer allocation."
        )]
        ///
        /// # Errors
        ///
        /// Returns an error if the NPP call reports a failure status or if the reported size
        /// cannot be converted to `usize`.
        pub fn $buffer_size_name(stream_context: &StreamContext, roi: Size) -> Result<usize> {
            let mut bytes = 0;
            // SAFETY: NPP writes only through `&raw mut bytes`, which points at a live local,
            // and the raw stream handle is borrowed from `stream_context` for this call.
            unsafe {
                try_ffi!(sys::$buffer_size_ffi(
                    roi.into(),
                    &raw mut bytes,
                    stream_context.as_raw(),
                ))?;
            }
            to_usize(bytes, "buffer size")
        }

        #[doc = concat!(
            "Runs `sys::",
            stringify!($ffi),
            "` over a batch of image pairs and writes the metric values into `results`."
        )]
        ///
        /// `source_0` and `source_1` are turned into NPP descriptor lists by
        /// `pair_metric_batch_descriptors`, which also yields the batch size and the single ROI
        /// handed to NPP.
        #[doc = concat!(
            "`results` is checked by `validate_metric_output` against that batch size and ",
            stringify!($channels),
            " channel(s) before anything is launched."
        )]
        ///
        /// # Errors
        ///
        /// Returns an error if descriptor construction, output validation, the buffer-size
        /// query, scratch allocation, or the NPP call fails.
        pub fn $name(
            stream_context: &StreamContext,
            source_0: &[ImageView<'_, $ty, $layout>],
            source_1: &[ImageView<'_, $ty, $layout>],
            results: &mut DeviceMemory<f32>,
        ) -> Result<()> {
            let (batch_size, roi, descriptors_0, descriptors_1) =
                pair_metric_batch_descriptors(source_0, source_1)?;
            let image_count = batch_size as usize;
            validate_metric_output(results, image_count, $channels)?;

            // The first tuple element owns the scratch allocations and must outlive the FFI
            // call below, so it is bound to a named (not `_`) variable.
            let (_scratch_storage, mut scratch_descriptors) = pair_metric_batch_buffers(
                image_count,
                $buffer_size_name(stream_context, roi)?,
            )?;

            // SAFETY: every pointer handed to NPP comes from a value owned by, or mutably
            // borrowed into, this frame, and each one stays alive until this function returns.
            // NOTE(review): the call is issued on the stream behind `stream_context`; whether
            // these allocations may be released before that stream finishes depends on
            // `DeviceMemory`'s drop behaviour — confirm.
            unsafe {
                try_ffi!(sys::$ffi(
                    descriptors_0.as_ptr().cast(),
                    descriptors_1.as_ptr().cast(),
                    batch_size,
                    roi.into(),
                    results.as_mut_ptr().cast(),
                    scratch_descriptors.as_mut_ptr().cast(),
                    stream_context.as_raw(),
                ))?;
            }
            Ok(())
        }
    };
}
65
/// Generates one public wrapper around an NPP `*_Advanced_Ctx` batch image-pair metric.
///
/// Positional arguments:
/// 1. an already-defined buffer-size query function (produced by
///    `impl_image_pair_metric_batch!`), which is called with the batch's maximum ROI,
/// 2. name of the generated metric function,
/// 3. pixel element type of the image views,
/// 4. channel layout marker type (e.g. `C1`, `C3`, `C4`),
/// 5. channel count forwarded to `validate_metric_output`,
/// 6. NPP advanced metric symbol in `sys`.
macro_rules! impl_image_pair_metric_batch_advanced {
    (
        $buffer_size_fn:ident,
        $name:ident,
        $ty:ty,
        $layout:ty,
        $channels:expr,
        $ffi:ident
    ) => {
        #[doc = concat!(
            "Runs `sys::",
            stringify!($ffi),
            "` over a batch of image pairs and writes the metric values into `results`."
        )]
        ///
        /// `source_0` and `source_1` are turned into NPP descriptor lists by
        /// `pair_metric_batch_advanced_descriptors`, which also yields the batch size and a
        /// maximum ROI. That maximum ROI sizes the scratch buffers and is passed to NPP.
        #[doc = concat!(
            "Scratch size comes from [`",
            stringify!($buffer_size_fn),
            "`]; `results` is checked against the batch size and ",
            stringify!($channels),
            " channel(s) before anything is launched."
        )]
        ///
        /// # Errors
        ///
        /// Returns an error if descriptor construction, output validation, the buffer-size
        /// query, scratch allocation, or the NPP call fails.
        pub fn $name(
            stream_context: &StreamContext,
            source_0: &[ImageView<'_, $ty, $layout>],
            source_1: &[ImageView<'_, $ty, $layout>],
            results: &mut DeviceMemory<f32>,
        ) -> Result<()> {
            let (batch_size, max_roi, descriptors_0, descriptors_1) =
                pair_metric_batch_advanced_descriptors(source_0, source_1)?;
            let image_count = batch_size as usize;
            validate_metric_output(results, image_count, $channels)?;

            // The first tuple element owns the scratch allocations and must outlive the FFI
            // call below, so it is bound to a named (not `_`) variable.
            let (_scratch_storage, mut scratch_descriptors) = pair_metric_batch_buffers(
                image_count,
                $buffer_size_fn(stream_context, max_roi)?,
            )?;

            // SAFETY: every pointer handed to NPP comes from a value owned by, or mutably
            // borrowed into, this frame, and each one stays alive until this function returns.
            // NOTE(review): the call is issued on the stream behind `stream_context`; whether
            // these allocations may be released before that stream finishes depends on
            // `DeviceMemory`'s drop behaviour — confirm.
            unsafe {
                try_ffi!(sys::$ffi(
                    descriptors_0.as_ptr().cast(),
                    descriptors_1.as_ptr().cast(),
                    batch_size,
                    max_roi.into(),
                    results.as_mut_ptr().cast(),
                    scratch_descriptors.as_mut_ptr().cast(),
                    stream_context.as_raw(),
                ))?;
            }
            Ok(())
        }
    };
}
103
/// Generates a pixel-type dispatch trait plus two generic entry points for one batch metric.
///
/// Positional arguments: the trait name, the generic metric function name, the generic
/// buffer-size function name, the channel layout marker, and a bracketed list of
/// `(pixel type, concrete metric fn, concrete buffer-size fn)` tuples. Each tuple becomes one
/// trait impl that forwards straight to its concrete functions.
macro_rules! impl_generic_image_pair_metric_batch {
    (
        $trait_name:ident,
        $name:ident,
        $buffer_size_name:ident,
        $layout:ty,
        [$(($ty:ty, $direct:ident, $direct_buffer_size:ident)),+ $(,)?]
    ) => {
        #[doc = concat!(
            "Pixel types accepted by [`",
            stringify!($name),
            "`] and [`",
            stringify!($buffer_size_name),
            "`]; each implementation forwards to a concrete, per-type wrapper."
        )]
        pub trait $trait_name<Layout>: DataTypeLike {
            /// Returns the scratch size, in bytes, reported by the concrete wrapper for `roi`.
            fn buffer_size(stream_context: &StreamContext, roi: Size) -> Result<usize>;

            /// Runs the concrete batch metric for `Self` over the paired image views and
            /// writes the values into `results`.
            fn dispatch(
                stream_context: &StreamContext,
                source_0: &[ImageView<'_, Self, Layout>],
                source_1: &[ImageView<'_, Self, Layout>],
                results: &mut DeviceMemory<f32>,
            ) -> Result<()>
            where
                Self: Sized;
        }

        $(
            impl $trait_name<$layout> for $ty {
                fn buffer_size(stream_context: &StreamContext, roi: Size) -> Result<usize> {
                    $direct_buffer_size(stream_context, roi)
                }

                fn dispatch(
                    stream_context: &StreamContext,
                    source_0: &[ImageView<'_, Self, $layout>],
                    source_1: &[ImageView<'_, Self, $layout>],
                    results: &mut DeviceMemory<f32>,
                ) -> Result<()> {
                    $direct(stream_context, source_0, source_1, results)
                }
            }
        )+

        #[doc = concat!(
            "Generic scratch-size query; picks the concrete wrapper from `T` via [`",
            stringify!($trait_name),
            "`]."
        )]
        pub fn $buffer_size_name<T>(stream_context: &StreamContext, roi: Size) -> Result<usize>
        where
            T: $trait_name<$layout>,
        {
            <T as $trait_name<$layout>>::buffer_size(stream_context, roi)
        }

        #[doc = concat!(
            "Generic batch metric; picks the concrete wrapper from `T` via [`",
            stringify!($trait_name),
            "`]."
        )]
        pub fn $name<T>(
            stream_context: &StreamContext,
            source_0: &[ImageView<'_, T, $layout>],
            source_1: &[ImageView<'_, T, $layout>],
            results: &mut DeviceMemory<f32>,
        ) -> Result<()>
        where
            T: $trait_name<$layout>,
        {
            <T as $trait_name<$layout>>::dispatch(stream_context, source_0, source_1, results)
        }
    };
}
153
154impl_image_pair_metric_batch!(
155    mse_batch_u8_c1_buffer_size,
156    mse_batch_u8_c1,
157    u8,
158    C1,
159    1,
160    nppiMSEBatchGetBufferHostSize_8u_C1R_Ctx,
161    nppiMSEBatch_8u_C1R_Ctx
162);
163impl_image_pair_metric_batch!(
164    mse_batch_u8_c3_buffer_size,
165    mse_batch_u8_c3,
166    u8,
167    C3,
168    3,
169    nppiMSEBatchGetBufferHostSize_8u_C3R_Ctx,
170    nppiMSEBatch_8u_C3R_Ctx
171);
172impl_image_pair_metric_batch!(
173    psnr_batch_u8_c1_buffer_size,
174    psnr_batch_u8_c1,
175    u8,
176    C1,
177    1,
178    nppiPSNRBatchGetBufferHostSize_8u_C1R_Ctx,
179    nppiPSNRBatch_8u_C1R_Ctx
180);
181impl_image_pair_metric_batch!(
182    psnr_batch_u8_c3_buffer_size,
183    psnr_batch_u8_c3,
184    u8,
185    C3,
186    3,
187    nppiPSNRBatchGetBufferHostSize_8u_C3R_Ctx,
188    nppiPSNRBatch_8u_C3R_Ctx
189);
190impl_image_pair_metric_batch!(
191    ssim_batch_u8_c1_buffer_size,
192    ssim_batch_u8_c1,
193    u8,
194    C1,
195    1,
196    nppiSSIMBatchGetBufferHostSize_8u_C1R_Ctx,
197    nppiSSIMBatch_8u_C1R_Ctx
198);
199impl_image_pair_metric_batch!(
200    ssim_batch_u8_c3_buffer_size,
201    ssim_batch_u8_c3,
202    u8,
203    C3,
204    3,
205    nppiSSIMBatchGetBufferHostSize_8u_C3R_Ctx,
206    nppiSSIMBatch_8u_C3R_Ctx
207);
208impl_image_pair_metric_batch!(
209    wmsssim_batch_u8_c1_buffer_size,
210    wmsssim_batch_u8_c1,
211    u8,
212    C1,
213    1,
214    nppiWMSSSIMBatchGetBufferHostSize_8u_C1R_Ctx,
215    nppiWMSSSIMBatch_8u_C1R_Ctx
216);
217impl_image_pair_metric_batch!(
218    wmsssim_batch_u8_c3_buffer_size,
219    wmsssim_batch_u8_c3,
220    u8,
221    C3,
222    3,
223    nppiWMSSSIMBatchGetBufferHostSize_8u_C3R_Ctx,
224    nppiWMSSSIMBatch_8u_C3R_Ctx
225);
226impl_image_pair_metric_batch!(
227    wmsssim_batch_u8_c4_buffer_size,
228    wmsssim_batch_u8_c4,
229    u8,
230    C4,
231    4,
232    nppiWMSSSIMBatchGetBufferHostSize_8u_C4R_Ctx,
233    nppiWMSSSIMBatch_8u_C4R_Ctx
234);
235impl_image_pair_metric_batch_advanced!(
236    mse_batch_u8_c1_buffer_size,
237    mse_batch_u8_c1_advanced,
238    u8,
239    C1,
240    1,
241    nppiMSEBatch_8u_C1R_Advanced_Ctx
242);
243impl_image_pair_metric_batch_advanced!(
244    mse_batch_u8_c3_buffer_size,
245    mse_batch_u8_c3_advanced,
246    u8,
247    C3,
248    3,
249    nppiMSEBatch_8u_C3R_Advanced_Ctx
250);
251impl_image_pair_metric_batch_advanced!(
252    psnr_batch_u8_c1_buffer_size,
253    psnr_batch_u8_c1_advanced,
254    u8,
255    C1,
256    1,
257    nppiPSNRBatch_8u_C1R_Advanced_Ctx
258);
259impl_image_pair_metric_batch_advanced!(
260    psnr_batch_u8_c3_buffer_size,
261    psnr_batch_u8_c3_advanced,
262    u8,
263    C3,
264    3,
265    nppiPSNRBatch_8u_C3R_Advanced_Ctx
266);
267impl_image_pair_metric_batch_advanced!(
268    wmsssim_batch_u8_c1_buffer_size,
269    wmsssim_batch_u8_c1_advanced,
270    u8,
271    C1,
272    1,
273    nppiWMSSSIMBatch_8u_C1R_Advanced_Ctx
274);
275impl_image_pair_metric_batch_advanced!(
276    wmsssim_batch_u8_c3_buffer_size,
277    wmsssim_batch_u8_c3_advanced,
278    u8,
279    C3,
280    3,
281    nppiWMSSSIMBatch_8u_C3R_Advanced_Ctx
282);
283impl_image_pair_metric_batch_advanced!(
284    wmsssim_batch_u8_c4_buffer_size,
285    wmsssim_batch_u8_c4_advanced,
286    u8,
287    C4,
288    4,
289    nppiWMSSSIMBatch_8u_C4R_Advanced_Ctx
290);
291impl_generic_image_pair_metric_batch!(
292    MseBatchC1,
293    mse_batch_c1,
294    mse_batch_c1_buffer_size,
295    C1,
296    [(u8, mse_batch_u8_c1, mse_batch_u8_c1_buffer_size)]
297);
298impl_generic_image_pair_metric_batch!(
299    MseBatchC3,
300    mse_batch_c3,
301    mse_batch_c3_buffer_size,
302    C3,
303    [(u8, mse_batch_u8_c3, mse_batch_u8_c3_buffer_size)]
304);
305impl_generic_image_pair_metric_batch!(
306    PsnrBatchC1,
307    psnr_batch_c1,
308    psnr_batch_c1_buffer_size,
309    C1,
310    [(u8, psnr_batch_u8_c1, psnr_batch_u8_c1_buffer_size)]
311);
312impl_generic_image_pair_metric_batch!(
313    PsnrBatchC3,
314    psnr_batch_c3,
315    psnr_batch_c3_buffer_size,
316    C3,
317    [(u8, psnr_batch_u8_c3, psnr_batch_u8_c3_buffer_size)]
318);
319impl_generic_image_pair_metric_batch!(
320    SsimBatchC1,
321    ssim_batch_c1,
322    ssim_batch_c1_buffer_size,
323    C1,
324    [(u8, ssim_batch_u8_c1, ssim_batch_u8_c1_buffer_size)]
325);
326impl_generic_image_pair_metric_batch!(
327    SsimBatchC3,
328    ssim_batch_c3,
329    ssim_batch_c3_buffer_size,
330    C3,
331    [(u8, ssim_batch_u8_c3, ssim_batch_u8_c3_buffer_size)]
332);
333impl_generic_image_pair_metric_batch!(
334    WmsssimBatchC1,
335    wmsssim_batch_c1,
336    wmsssim_batch_c1_buffer_size,
337    C1,
338    [(u8, wmsssim_batch_u8_c1, wmsssim_batch_u8_c1_buffer_size)]
339);
340impl_generic_image_pair_metric_batch!(
341    WmsssimBatchC3,
342    wmsssim_batch_c3,
343    wmsssim_batch_c3_buffer_size,
344    C3,
345    [(u8, wmsssim_batch_u8_c3, wmsssim_batch_u8_c3_buffer_size)]
346);
347impl_generic_image_pair_metric_batch!(
348    WmsssimBatchC4,
349    wmsssim_batch_c4,
350    wmsssim_batch_c4_buffer_size,
351    C4,
352    [(u8, wmsssim_batch_u8_c4, wmsssim_batch_u8_c4_buffer_size)]
353);
354impl_generic_image_pair_metric_batch!(
355    MseBatchAdvancedC1,
356    mse_batch_advanced_c1,
357    mse_batch_advanced_c1_buffer_size,
358    C1,
359    [(u8, mse_batch_u8_c1_advanced, mse_batch_u8_c1_buffer_size)]
360);
361impl_generic_image_pair_metric_batch!(
362    MseBatchAdvancedC3,
363    mse_batch_advanced_c3,
364    mse_batch_advanced_c3_buffer_size,
365    C3,
366    [(u8, mse_batch_u8_c3_advanced, mse_batch_u8_c3_buffer_size)]
367);
368impl_generic_image_pair_metric_batch!(
369    PsnrBatchAdvancedC1,
370    psnr_batch_advanced_c1,
371    psnr_batch_advanced_c1_buffer_size,
372    C1,
373    [(u8, psnr_batch_u8_c1_advanced, psnr_batch_u8_c1_buffer_size)]
374);
375impl_generic_image_pair_metric_batch!(
376    PsnrBatchAdvancedC3,
377    psnr_batch_advanced_c3,
378    psnr_batch_advanced_c3_buffer_size,
379    C3,
380    [(u8, psnr_batch_u8_c3_advanced, psnr_batch_u8_c3_buffer_size)]
381);
382impl_generic_image_pair_metric_batch!(
383    WmsssimBatchAdvancedC1,
384    wmsssim_batch_advanced_c1,
385    wmsssim_batch_advanced_c1_buffer_size,
386    C1,
387    [(
388        u8,
389        wmsssim_batch_u8_c1_advanced,
390        wmsssim_batch_u8_c1_buffer_size
391    )]
392);
393impl_generic_image_pair_metric_batch!(
394    WmsssimBatchAdvancedC3,
395    wmsssim_batch_advanced_c3,
396    wmsssim_batch_advanced_c3_buffer_size,
397    C3,
398    [(
399        u8,
400        wmsssim_batch_u8_c3_advanced,
401        wmsssim_batch_u8_c3_buffer_size
402    )]
403);
404impl_generic_image_pair_metric_batch!(
405    WmsssimBatchAdvancedC4,
406    wmsssim_batch_advanced_c4,
407    wmsssim_batch_advanced_c4_buffer_size,
408    C4,
409    [(
410        u8,
411        wmsssim_batch_u8_c4_advanced,
412        wmsssim_batch_u8_c4_buffer_size
413    )]
414);