1use singe_cuda::memory::DeviceMemory;
2use singe_npp_sys as sys;
3
4use crate::{
5 context::StreamContext,
6 error::Result,
7 image::view::{C1, C3, C4, ImageView},
8 try_ffi,
9 types::{DataTypeLike, Size},
10 utility::to_usize,
11};
12
13use super::statistics_validation::*;
14
/// Generates two public wrappers around one NPP two-image metric primitive.
///
/// Arguments, in order: the name of the generated scratch-size function, the name of the
/// generated metric function, the element type, the channel-layout marker, the channel count
/// forwarded to `validate_metric_output`, and the `sys` identifiers of the buffer-size query
/// and of the metric primitive.
macro_rules! impl_image_pair_metric {
    (
        $buffer_size_fn:ident,
        $metric_fn:ident,
        $elem:ty,
        $layout:ty,
        $channels:expr,
        $buffer_size_ffi:ident,
        $metric_ffi:ident
    ) => {
        #[doc = concat!(
            "Asks `",
            stringify!($buffer_size_ffi),
            "` how many scratch bytes `",
            stringify!($metric_fn),
            "` needs for `roi`."
        )]
        ///
        /// # Errors
        ///
        /// Propagates whatever `try_ffi!` reports for the NPP call, or the `to_usize` failure
        /// when the reported size cannot be converted.
        pub fn $buffer_size_fn(stream_context: &StreamContext, roi: Size) -> Result<usize> {
            let mut reported = 0;
            // SAFETY: `reported` is a live local for the duration of the call, so the
            // out-pointer handed to NPP is valid to write through.
            unsafe {
                try_ffi!(sys::$buffer_size_ffi(
                    roi.into(),
                    &raw mut reported,
                    stream_context.as_raw(),
                ))?;
            }
            to_usize(reported, "buffer size")
        }

        #[doc = concat!(
            "Runs `",
            stringify!($metric_ffi),
            "` on `source_1` and `source_2` and has NPP write the result into `output`."
        )]
        ///
        /// The two views must pass `validate_same_size`, and `output` must pass
        /// `validate_metric_output` for this layout's channel count. A scratch buffer of the
        /// size reported by the matching buffer-size wrapper is allocated for every call.
        ///
        /// # Errors
        ///
        /// Returns the first error raised by validation, the buffer-size query, the scratch
        /// allocation, or the NPP call itself.
        pub fn $metric_fn(
            stream_context: &StreamContext,
            source_1: &ImageView<'_, $elem, $layout>,
            source_2: &ImageView<'_, $elem, $layout>,
            output: &mut DeviceMemory<f32>,
        ) -> Result<()> {
            validate_same_size(source_1.size(), source_2.size())?;
            validate_metric_output(output, 1, $channels)?;

            // Size the scratch area for the ROI that is passed to NPP below.
            let workspace = DeviceMemory::<u8>::create($buffer_size_fn(
                stream_context,
                source_1.size(),
            )?)?;

            // SAFETY: pointers and steps come from the two views that were just checked to
            // share one size, `output` passed `validate_metric_output`, and `workspace` holds
            // the byte count NPP requested for this ROI.
            // NOTE(review): the call is issued on `stream_context` and `workspace` is dropped
            // on return; confirm that `DeviceMemory`'s drop cannot release it while NPP may
            // still be using it.
            unsafe {
                try_ffi!(sys::$metric_ffi(
                    source_1.as_ptr().cast(),
                    source_1.step(),
                    source_2.as_ptr().cast(),
                    source_2.step(),
                    source_1.size().into(),
                    output.as_mut_ptr().cast(),
                    workspace.as_mut_ptr().cast(),
                    stream_context.as_raw(),
                ))?;
            }
            Ok(())
        }
    };
}
64
/// Generates a dispatch trait plus two generic entry points for one metric and layout.
///
/// Arguments, in order: the trait name, the generic metric function name, the generic
/// buffer-size function name, the layout marker, and a bracketed list of
/// `(element type, direct metric fn, direct buffer-size fn)` tuples. Each tuple becomes one
/// trait impl that forwards to the named direct functions.
macro_rules! impl_generic_image_pair_metric {
    (
        $dispatch_trait:ident,
        $metric_fn:ident,
        $buffer_size_fn:ident,
        $layout:ty,
        [$(($elem:ty, $direct_metric:ident, $direct_buffer_size:ident)),+ $(,)?]
    ) => {
        #[doc = concat!(
            "Element types for which `",
            stringify!($metric_fn),
            "` has a direct implementation on a given `Layout`."
        )]
        pub trait $dispatch_trait<Layout>: DataTypeLike {
            /// Returns the scratch size the metric needs for `roi`.
            fn buffer_size(stream_context: &StreamContext, roi: Size) -> Result<usize>;

            /// Runs the metric on the two views and writes the result into `output`.
            fn dispatch(
                stream_context: &StreamContext,
                source_1: &ImageView<'_, Self, Layout>,
                source_2: &ImageView<'_, Self, Layout>,
                output: &mut DeviceMemory<f32>,
            ) -> Result<()>
            where
                Self: Sized;
        }

        $(
            impl $dispatch_trait<$layout> for $elem {
                fn buffer_size(stream_context: &StreamContext, roi: Size) -> Result<usize> {
                    $direct_buffer_size(stream_context, roi)
                }

                fn dispatch(
                    stream_context: &StreamContext,
                    source_1: &ImageView<'_, Self, $layout>,
                    source_2: &ImageView<'_, Self, $layout>,
                    output: &mut DeviceMemory<f32>,
                ) -> Result<()> {
                    $direct_metric(stream_context, source_1, source_2, output)
                }
            }
        )+

        #[doc = concat!(
            "Scratch size needed by `",
            stringify!($metric_fn),
            "` for `roi`, resolved through `",
            stringify!($dispatch_trait),
            "` for the element type `T`."
        )]
        pub fn $buffer_size_fn<T>(stream_context: &StreamContext, roi: Size) -> Result<usize>
        where
            T: $dispatch_trait<$layout>,
        {
            <T as $dispatch_trait<$layout>>::buffer_size(stream_context, roi)
        }

        #[doc = concat!(
            "Element-type-generic entry point that forwards to the `",
            stringify!($dispatch_trait),
            "` implementation for `T`."
        )]
        pub fn $metric_fn<T>(
            stream_context: &StreamContext,
            source_1: &ImageView<'_, T, $layout>,
            source_2: &ImageView<'_, T, $layout>,
            output: &mut DeviceMemory<f32>,
        ) -> Result<()>
        where
            T: $dispatch_trait<$layout>,
        {
            <T as $dispatch_trait<$layout>>::dispatch(stream_context, source_1, source_2, output)
        }
    };
}
114
// Direct `u8` wrappers, one invocation per NPP primitive.
//
// Argument order for every invocation below:
//   1. name of the generated buffer-size function
//   2. name of the generated metric function
//   3. element type
//   4. channel-layout marker
//   5. channel count forwarded to `validate_metric_output`
//   6. `sys` buffer-size query
//   7. `sys` metric primitive

// MSE: C1 and C3.
impl_image_pair_metric!(
    mse_u8_c1_buffer_size,
    mse_u8_c1,
    u8,
    C1,
    1,
    nppiMSEGetBufferHostSize_8u_C1R_Ctx,
    nppiMSE_8u_C1R_Ctx
);
impl_image_pair_metric!(
    mse_u8_c3_buffer_size,
    mse_u8_c3,
    u8,
    C3,
    3,
    nppiMSEGetBufferHostSize_8u_C3R_Ctx,
    nppiMSE_8u_C3R_Ctx
);
// PSNR: C1 and C3.
impl_image_pair_metric!(
    psnr_u8_c1_buffer_size,
    psnr_u8_c1,
    u8,
    C1,
    1,
    nppiPSNRGetBufferHostSize_8u_C1R_Ctx,
    nppiPSNR_8u_C1R_Ctx
);
impl_image_pair_metric!(
    psnr_u8_c3_buffer_size,
    psnr_u8_c3,
    u8,
    C3,
    3,
    nppiPSNRGetBufferHostSize_8u_C3R_Ctx,
    nppiPSNR_8u_C3R_Ctx
);
// SSIM: C1 and C3.
impl_image_pair_metric!(
    ssim_u8_c1_buffer_size,
    ssim_u8_c1,
    u8,
    C1,
    1,
    nppiSSIMGetBufferHostSize_8u_C1R_Ctx,
    nppiSSIM_8u_C1R_Ctx
);
impl_image_pair_metric!(
    ssim_u8_c3_buffer_size,
    ssim_u8_c3,
    u8,
    C3,
    3,
    nppiSSIMGetBufferHostSize_8u_C3R_Ctx,
    nppiSSIM_8u_C3R_Ctx
);
// MSSSIM: only the C1 variant is wrapped in this section.
impl_image_pair_metric!(
    msssim_u8_c1_buffer_size,
    msssim_u8_c1,
    u8,
    C1,
    1,
    nppiMSSSIMGetBufferHostSize_8u_C1R_Ctx,
    nppiMSSSIM_8u_C1R_Ctx
);
// WMSSSIM: C1, C3 and C4.
impl_image_pair_metric!(
    wmsssim_u8_c1_buffer_size,
    wmsssim_u8_c1,
    u8,
    C1,
    1,
    nppiWMSSSIMGetBufferHostSize_8u_C1R_Ctx,
    nppiWMSSSIM_8u_C1R_Ctx
);
impl_image_pair_metric!(
    wmsssim_u8_c3_buffer_size,
    wmsssim_u8_c3,
    u8,
    C3,
    3,
    nppiWMSSSIMGetBufferHostSize_8u_C3R_Ctx,
    nppiWMSSSIM_8u_C3R_Ctx
);
impl_image_pair_metric!(
    wmsssim_u8_c4_buffer_size,
    wmsssim_u8_c4,
    u8,
    C4,
    4,
    nppiWMSSSIMGetBufferHostSize_8u_C4R_Ctx,
    nppiWMSSSIM_8u_C4R_Ctx
);
// Element-type-generic front ends, one per metric/layout pair wrapped above.
//
// Argument order for every invocation below:
//   1. dispatch trait to generate
//   2. name of the generic metric function
//   3. name of the generic buffer-size function
//   4. channel-layout marker
//   5. list of `(element type, direct metric fn, direct buffer-size fn)` tuples
//
// Every list currently holds a single `u8` entry; supporting another element type means
// adding a tuple that names its direct wrappers.

// MSE.
impl_generic_image_pair_metric!(
    MseC1,
    mse_c1,
    mse_c1_buffer_size,
    C1,
    [(u8, mse_u8_c1, mse_u8_c1_buffer_size)]
);
impl_generic_image_pair_metric!(
    MseC3,
    mse_c3,
    mse_c3_buffer_size,
    C3,
    [(u8, mse_u8_c3, mse_u8_c3_buffer_size)]
);
// PSNR.
impl_generic_image_pair_metric!(
    PsnrC1,
    psnr_c1,
    psnr_c1_buffer_size,
    C1,
    [(u8, psnr_u8_c1, psnr_u8_c1_buffer_size)]
);
impl_generic_image_pair_metric!(
    PsnrC3,
    psnr_c3,
    psnr_c3_buffer_size,
    C3,
    [(u8, psnr_u8_c3, psnr_u8_c3_buffer_size)]
);
// SSIM.
impl_generic_image_pair_metric!(
    SsimC1,
    ssim_c1,
    ssim_c1_buffer_size,
    C1,
    [(u8, ssim_u8_c1, ssim_u8_c1_buffer_size)]
);
impl_generic_image_pair_metric!(
    SsimC3,
    ssim_c3,
    ssim_c3_buffer_size,
    C3,
    [(u8, ssim_u8_c3, ssim_u8_c3_buffer_size)]
);
// MSSSIM (C1 only).
impl_generic_image_pair_metric!(
    MsssimC1,
    msssim_c1,
    msssim_c1_buffer_size,
    C1,
    [(u8, msssim_u8_c1, msssim_u8_c1_buffer_size)]
);
// WMSSSIM.
impl_generic_image_pair_metric!(
    WmsssimC1,
    wmsssim_c1,
    wmsssim_c1_buffer_size,
    C1,
    [(u8, wmsssim_u8_c1, wmsssim_u8_c1_buffer_size)]
);
impl_generic_image_pair_metric!(
    WmsssimC3,
    wmsssim_c3,
    wmsssim_c3_buffer_size,
    C3,
    [(u8, wmsssim_u8_c3, wmsssim_u8_c3_buffer_size)]
);
impl_generic_image_pair_metric!(
    WmsssimC4,
    wmsssim_c4,
    wmsssim_c4_buffer_size,
    C4,
    [(u8, wmsssim_u8_c4, wmsssim_u8_c4_buffer_size)]
);