1use singe_cuda::memory::DeviceMemory;
2use singe_npp_sys as sys;
3
4use crate::{
5 context::StreamContext,
6 error::Result,
7 image::view::{C1, C3, C4, ImageView},
8 try_ffi,
9 types::{DataTypeLike, Size},
10 utility::to_usize,
11};
12
13use super::statistics_validation::*;
14
// Shared body of the batched image-pair metric wrappers.
//
// `impl_image_pair_metric_batch!` and `impl_image_pair_metric_batch_advanced!` used to
// carry two copies of this function that differed only in the helper that builds the
// descriptor lists (and the name of the ROI local). That helper is now the
// `$descriptors` argument, so a change to the call sequence only has to be made once.
macro_rules! impl_image_pair_metric_batch_fn {
    (
        $name:ident,
        $buffer_size_name:ident,
        $ty:ty,
        $layout:ty,
        $channels:expr,
        $descriptors:ident,
        $ffi:ident
    ) => {
        #[doc = concat!(
            "Runs `", stringify!($ffi), "` on the image pairs in `source_0` and `source_1`."
        )]
        #[doc = ""]
        #[doc = concat!(
            "Descriptor lists and the ROI come from `", stringify!($descriptors), "`; ",
            "`results` is validated against the batch size and a channel count of ",
            stringify!($channels), " before launch, and scratch space is sized with [`",
            stringify!($buffer_size_name), "`]."
        )]
        #[doc = ""]
        #[doc = "# Errors"]
        #[doc = ""]
        #[doc = "Returns an error if the descriptor lists cannot be built from the inputs, \
                 if `results` fails validation, if the buffer-size query or scratch \
                 allocation fails, or if NPP reports a failing status."]
        pub fn $name(
            stream_context: &StreamContext,
            source_0: &[ImageView<'_, $ty, $layout>],
            source_1: &[ImageView<'_, $ty, $layout>],
            results: &mut DeviceMemory<f32>,
        ) -> Result<()> {
            let (batch_size, roi, descriptors_0, descriptors_1) =
                $descriptors(source_0, source_1)?;
            validate_metric_output(results, batch_size as usize, $channels)?;
            let required_bytes = $buffer_size_name(stream_context, roi)?;
            // Bound to a named (underscore-prefixed) local rather than `_`: a `_` pattern
            // would drop the value at the end of this statement, while `_scratch_buffers`
            // lives until the function returns. It presumably owns the allocations that
            // `scratch_descriptors` points at.
            let (_scratch_buffers, mut scratch_descriptors) =
                pair_metric_batch_buffers(batch_size as usize, required_bytes)?;

            // SAFETY: every pointer handed to NPP comes from a local that stays alive until
            // this function returns: both descriptor lists, `results` (already checked by
            // `validate_metric_output`), and `scratch_descriptors`.
            // NOTE(review): `_Ctx` calls are presumably asynchronous on `stream_context`'s
            // stream, so the kernel may still be running when these locals drop on return.
            // Confirm that dropping `DeviceMemory` is stream-ordered or synchronizing.
            unsafe {
                try_ffi!(sys::$ffi(
                    descriptors_0.as_ptr().cast(),
                    descriptors_1.as_ptr().cast(),
                    batch_size,
                    roi.into(),
                    results.as_mut_ptr().cast(),
                    scratch_descriptors.as_mut_ptr().cast(),
                    stream_context.as_raw(),
                ))?;
            }
            Ok(())
        }
    };
}

// Generates a `$buffer_size_name` scratch-size query plus a `$name` batch function whose
// descriptor lists are built by `pair_metric_batch_descriptors`.
macro_rules! impl_image_pair_metric_batch {
    (
        $buffer_size_name:ident,
        $name:ident,
        $ty:ty,
        $layout:ty,
        $channels:expr,
        $buffer_size_ffi:ident,
        $ffi:ident
    ) => {
        #[doc = concat!(
            "Returns the scratch-buffer size in bytes that `", stringify!($buffer_size_ffi),
            "` reports for `roi`."
        )]
        #[doc = ""]
        #[doc = "# Errors"]
        #[doc = ""]
        #[doc = "Returns an error if NPP reports a failing status or the reported size \
                 cannot be converted to `usize`."]
        pub fn $buffer_size_name(stream_context: &StreamContext, roi: Size) -> Result<usize> {
            let mut bytes = 0;
            // SAFETY: `bytes` is a live local that outlives the call and is the only
            // out-pointer passed to NPP.
            unsafe {
                try_ffi!(sys::$buffer_size_ffi(
                    roi.into(),
                    &raw mut bytes,
                    stream_context.as_raw(),
                ))?;
            }
            to_usize(bytes, "buffer size")
        }

        impl_image_pair_metric_batch_fn!(
            $name,
            $buffer_size_name,
            $ty,
            $layout,
            $channels,
            pair_metric_batch_descriptors,
            $ffi
        );
    };
}

// Generates a `$name` batch function whose descriptor lists (and the ROI passed to NPP)
// come from `pair_metric_batch_advanced_descriptors`. It reuses an already-generated
// `$buffer_size_name` query instead of defining its own.
macro_rules! impl_image_pair_metric_batch_advanced {
    (
        $buffer_size_name:ident,
        $name:ident,
        $ty:ty,
        $layout:ty,
        $channels:expr,
        $ffi:ident
    ) => {
        impl_image_pair_metric_batch_fn!(
            $name,
            $buffer_size_name,
            $ty,
            $layout,
            $channels,
            pair_metric_batch_advanced_descriptors,
            $ffi
        );
    };
}
103
// Builds a trait that maps an element type to its concrete batched-metric functions,
// plus two generic entry points (`$buffer_size_name::<T>` and `$name::<T>`) that forward
// through that trait. Each `(type, batch fn, buffer-size fn)` tuple yields one impl.
macro_rules! impl_generic_image_pair_metric_batch {
    (
        $trait_name:ident,
        $name:ident,
        $buffer_size_name:ident,
        $layout:ty,
        [$(($elem:ty, $concrete_fn:ident, $concrete_buffer_size:ident)),+ $(,)?]
    ) => {
        #[doc = concat!(
            "Element types with a batched `", stringify!($name), "` implementation."
        )]
        #[doc = ""]
        #[doc = "Each implementation forwards to a type-specific wrapper function."]
        pub trait $trait_name<Layout>: DataTypeLike {
            /// Scratch-buffer size, in bytes, needed by a batch call over `roi`.
            fn buffer_size(stream_context: &StreamContext, roi: Size) -> Result<usize>;

            /// Runs the batched metric on `source_0` / `source_1`, writing into `results`.
            fn dispatch(
                stream_context: &StreamContext,
                source_0: &[ImageView<'_, Self, Layout>],
                source_1: &[ImageView<'_, Self, Layout>],
                results: &mut DeviceMemory<f32>,
            ) -> Result<()>
            where
                Self: Sized;
        }

        $(
            impl $trait_name<$layout> for $elem {
                fn buffer_size(stream_context: &StreamContext, roi: Size) -> Result<usize> {
                    $concrete_buffer_size(stream_context, roi)
                }

                fn dispatch(
                    stream_context: &StreamContext,
                    source_0: &[ImageView<'_, Self, $layout>],
                    source_1: &[ImageView<'_, Self, $layout>],
                    results: &mut DeviceMemory<f32>,
                ) -> Result<()> {
                    $concrete_fn(stream_context, source_0, source_1, results)
                }
            }
        )+

        #[doc = concat!(
            "Generic scratch-size query: forwards to `", stringify!($trait_name),
            "::buffer_size` for element type `T`."
        )]
        pub fn $buffer_size_name<T>(stream_context: &StreamContext, roi: Size) -> Result<usize>
        where
            T: $trait_name<$layout>,
        {
            <T as $trait_name<$layout>>::buffer_size(stream_context, roi)
        }

        #[doc = concat!(
            "Generic batch entry point: forwards to `", stringify!($trait_name),
            "::dispatch` for element type `T`."
        )]
        pub fn $name<T>(
            stream_context: &StreamContext,
            source_0: &[ImageView<'_, T, $layout>],
            source_1: &[ImageView<'_, T, $layout>],
            results: &mut DeviceMemory<f32>,
        ) -> Result<()>
        where
            T: $trait_name<$layout>,
        {
            <T as $trait_name<$layout>>::dispatch(stream_context, source_0, source_1, results)
        }
    };
}
153
154impl_image_pair_metric_batch!(
155 mse_batch_u8_c1_buffer_size,
156 mse_batch_u8_c1,
157 u8,
158 C1,
159 1,
160 nppiMSEBatchGetBufferHostSize_8u_C1R_Ctx,
161 nppiMSEBatch_8u_C1R_Ctx
162);
163impl_image_pair_metric_batch!(
164 mse_batch_u8_c3_buffer_size,
165 mse_batch_u8_c3,
166 u8,
167 C3,
168 3,
169 nppiMSEBatchGetBufferHostSize_8u_C3R_Ctx,
170 nppiMSEBatch_8u_C3R_Ctx
171);
172impl_image_pair_metric_batch!(
173 psnr_batch_u8_c1_buffer_size,
174 psnr_batch_u8_c1,
175 u8,
176 C1,
177 1,
178 nppiPSNRBatchGetBufferHostSize_8u_C1R_Ctx,
179 nppiPSNRBatch_8u_C1R_Ctx
180);
181impl_image_pair_metric_batch!(
182 psnr_batch_u8_c3_buffer_size,
183 psnr_batch_u8_c3,
184 u8,
185 C3,
186 3,
187 nppiPSNRBatchGetBufferHostSize_8u_C3R_Ctx,
188 nppiPSNRBatch_8u_C3R_Ctx
189);
190impl_image_pair_metric_batch!(
191 ssim_batch_u8_c1_buffer_size,
192 ssim_batch_u8_c1,
193 u8,
194 C1,
195 1,
196 nppiSSIMBatchGetBufferHostSize_8u_C1R_Ctx,
197 nppiSSIMBatch_8u_C1R_Ctx
198);
199impl_image_pair_metric_batch!(
200 ssim_batch_u8_c3_buffer_size,
201 ssim_batch_u8_c3,
202 u8,
203 C3,
204 3,
205 nppiSSIMBatchGetBufferHostSize_8u_C3R_Ctx,
206 nppiSSIMBatch_8u_C3R_Ctx
207);
208impl_image_pair_metric_batch!(
209 wmsssim_batch_u8_c1_buffer_size,
210 wmsssim_batch_u8_c1,
211 u8,
212 C1,
213 1,
214 nppiWMSSSIMBatchGetBufferHostSize_8u_C1R_Ctx,
215 nppiWMSSSIMBatch_8u_C1R_Ctx
216);
217impl_image_pair_metric_batch!(
218 wmsssim_batch_u8_c3_buffer_size,
219 wmsssim_batch_u8_c3,
220 u8,
221 C3,
222 3,
223 nppiWMSSSIMBatchGetBufferHostSize_8u_C3R_Ctx,
224 nppiWMSSSIMBatch_8u_C3R_Ctx
225);
226impl_image_pair_metric_batch!(
227 wmsssim_batch_u8_c4_buffer_size,
228 wmsssim_batch_u8_c4,
229 u8,
230 C4,
231 4,
232 nppiWMSSSIMBatchGetBufferHostSize_8u_C4R_Ctx,
233 nppiWMSSSIMBatch_8u_C4R_Ctx
234);
235impl_image_pair_metric_batch_advanced!(
236 mse_batch_u8_c1_buffer_size,
237 mse_batch_u8_c1_advanced,
238 u8,
239 C1,
240 1,
241 nppiMSEBatch_8u_C1R_Advanced_Ctx
242);
243impl_image_pair_metric_batch_advanced!(
244 mse_batch_u8_c3_buffer_size,
245 mse_batch_u8_c3_advanced,
246 u8,
247 C3,
248 3,
249 nppiMSEBatch_8u_C3R_Advanced_Ctx
250);
251impl_image_pair_metric_batch_advanced!(
252 psnr_batch_u8_c1_buffer_size,
253 psnr_batch_u8_c1_advanced,
254 u8,
255 C1,
256 1,
257 nppiPSNRBatch_8u_C1R_Advanced_Ctx
258);
259impl_image_pair_metric_batch_advanced!(
260 psnr_batch_u8_c3_buffer_size,
261 psnr_batch_u8_c3_advanced,
262 u8,
263 C3,
264 3,
265 nppiPSNRBatch_8u_C3R_Advanced_Ctx
266);
267impl_image_pair_metric_batch_advanced!(
268 wmsssim_batch_u8_c1_buffer_size,
269 wmsssim_batch_u8_c1_advanced,
270 u8,
271 C1,
272 1,
273 nppiWMSSSIMBatch_8u_C1R_Advanced_Ctx
274);
275impl_image_pair_metric_batch_advanced!(
276 wmsssim_batch_u8_c3_buffer_size,
277 wmsssim_batch_u8_c3_advanced,
278 u8,
279 C3,
280 3,
281 nppiWMSSSIMBatch_8u_C3R_Advanced_Ctx
282);
283impl_image_pair_metric_batch_advanced!(
284 wmsssim_batch_u8_c4_buffer_size,
285 wmsssim_batch_u8_c4_advanced,
286 u8,
287 C4,
288 4,
289 nppiWMSSSIMBatch_8u_C4R_Advanced_Ctx
290);
291impl_generic_image_pair_metric_batch!(
292 MseBatchC1,
293 mse_batch_c1,
294 mse_batch_c1_buffer_size,
295 C1,
296 [(u8, mse_batch_u8_c1, mse_batch_u8_c1_buffer_size)]
297);
298impl_generic_image_pair_metric_batch!(
299 MseBatchC3,
300 mse_batch_c3,
301 mse_batch_c3_buffer_size,
302 C3,
303 [(u8, mse_batch_u8_c3, mse_batch_u8_c3_buffer_size)]
304);
305impl_generic_image_pair_metric_batch!(
306 PsnrBatchC1,
307 psnr_batch_c1,
308 psnr_batch_c1_buffer_size,
309 C1,
310 [(u8, psnr_batch_u8_c1, psnr_batch_u8_c1_buffer_size)]
311);
312impl_generic_image_pair_metric_batch!(
313 PsnrBatchC3,
314 psnr_batch_c3,
315 psnr_batch_c3_buffer_size,
316 C3,
317 [(u8, psnr_batch_u8_c3, psnr_batch_u8_c3_buffer_size)]
318);
319impl_generic_image_pair_metric_batch!(
320 SsimBatchC1,
321 ssim_batch_c1,
322 ssim_batch_c1_buffer_size,
323 C1,
324 [(u8, ssim_batch_u8_c1, ssim_batch_u8_c1_buffer_size)]
325);
326impl_generic_image_pair_metric_batch!(
327 SsimBatchC3,
328 ssim_batch_c3,
329 ssim_batch_c3_buffer_size,
330 C3,
331 [(u8, ssim_batch_u8_c3, ssim_batch_u8_c3_buffer_size)]
332);
333impl_generic_image_pair_metric_batch!(
334 WmsssimBatchC1,
335 wmsssim_batch_c1,
336 wmsssim_batch_c1_buffer_size,
337 C1,
338 [(u8, wmsssim_batch_u8_c1, wmsssim_batch_u8_c1_buffer_size)]
339);
340impl_generic_image_pair_metric_batch!(
341 WmsssimBatchC3,
342 wmsssim_batch_c3,
343 wmsssim_batch_c3_buffer_size,
344 C3,
345 [(u8, wmsssim_batch_u8_c3, wmsssim_batch_u8_c3_buffer_size)]
346);
347impl_generic_image_pair_metric_batch!(
348 WmsssimBatchC4,
349 wmsssim_batch_c4,
350 wmsssim_batch_c4_buffer_size,
351 C4,
352 [(u8, wmsssim_batch_u8_c4, wmsssim_batch_u8_c4_buffer_size)]
353);
354impl_generic_image_pair_metric_batch!(
355 MseBatchAdvancedC1,
356 mse_batch_advanced_c1,
357 mse_batch_advanced_c1_buffer_size,
358 C1,
359 [(u8, mse_batch_u8_c1_advanced, mse_batch_u8_c1_buffer_size)]
360);
361impl_generic_image_pair_metric_batch!(
362 MseBatchAdvancedC3,
363 mse_batch_advanced_c3,
364 mse_batch_advanced_c3_buffer_size,
365 C3,
366 [(u8, mse_batch_u8_c3_advanced, mse_batch_u8_c3_buffer_size)]
367);
368impl_generic_image_pair_metric_batch!(
369 PsnrBatchAdvancedC1,
370 psnr_batch_advanced_c1,
371 psnr_batch_advanced_c1_buffer_size,
372 C1,
373 [(u8, psnr_batch_u8_c1_advanced, psnr_batch_u8_c1_buffer_size)]
374);
375impl_generic_image_pair_metric_batch!(
376 PsnrBatchAdvancedC3,
377 psnr_batch_advanced_c3,
378 psnr_batch_advanced_c3_buffer_size,
379 C3,
380 [(u8, psnr_batch_u8_c3_advanced, psnr_batch_u8_c3_buffer_size)]
381);
382impl_generic_image_pair_metric_batch!(
383 WmsssimBatchAdvancedC1,
384 wmsssim_batch_advanced_c1,
385 wmsssim_batch_advanced_c1_buffer_size,
386 C1,
387 [(
388 u8,
389 wmsssim_batch_u8_c1_advanced,
390 wmsssim_batch_u8_c1_buffer_size
391 )]
392);
393impl_generic_image_pair_metric_batch!(
394 WmsssimBatchAdvancedC3,
395 wmsssim_batch_advanced_c3,
396 wmsssim_batch_advanced_c3_buffer_size,
397 C3,
398 [(
399 u8,
400 wmsssim_batch_u8_c3_advanced,
401 wmsssim_batch_u8_c3_buffer_size
402 )]
403);
404impl_generic_image_pair_metric_batch!(
405 WmsssimBatchAdvancedC4,
406 wmsssim_batch_advanced_c4,
407 wmsssim_batch_advanced_c4_buffer_size,
408 C4,
409 [(
410 u8,
411 wmsssim_batch_u8_c4_advanced,
412 wmsssim_batch_u8_c4_buffer_size
413 )]
414);