1use std::{marker::PhantomData, ptr::NonNull};
2
3use singe_cuda::{
4 memory::{DeviceMemory, MemoryCopyKind},
5 types::{Complex32, f16},
6};
7use singe_npp_sys as sys;
8
9use crate::{
10 error::{Error, Result},
11 image::view::{AC4, C1, C2, C3, C4, ChannelLayout, ImageView, ImageViewMut},
12 types::{ComplexI16, ComplexI32, DataTypeLike, Size},
13 utility::checked_len,
14};
15
/// Owned image buffer with element type `T` and channel layout `L`.
///
/// Values are produced by the `create_*` constructors in this file, which obtain the buffer
/// from an `nppiMalloc_*` allocator; the buffer is handed to `nppiFree` when the value is
/// dropped.
#[derive(Debug)]
pub struct Image<T, L = C1> {
    /// Base pointer of the allocation; the constructors reject a null result.
    ptr: *mut T,
    /// Dimensions that were passed to the allocator.
    size: Size,
    /// Byte distance between the starts of consecutive rows, as written by the allocator.
    step: i32,
    /// Zero-sized marker that ties the image to its channel layout `L`.
    _t: PhantomData<L>,
}
30
31pub trait SupportedImage<Layout>: DataTypeLike + private::ImageCreate<Layout>
33where
34 Layout: ChannelLayout,
35{
36}
37
38impl<T, Layout> SupportedImage<Layout> for T
39where
40 T: DataTypeLike + private::ImageCreate<Layout>,
41 Layout: ChannelLayout,
42{
43}
44
45mod private {
46 use super::*;
47
48 pub trait ImageCreate<Layout>: DataTypeLike + Sized
49 where
50 Layout: ChannelLayout,
51 {
52 fn create(size: Size) -> Result<Image<Self, Layout>>;
53 }
54}
55
// Implements `private::ImageCreate<$layout>` for every listed element type by forwarding
// to the named private constructor on `Image<element, $layout>`.
//
// Usage: `impl_image_create!(Layout, [ElementType => constructor_name, ...]);`
macro_rules! impl_image_create {
    ($layout:ty, [$($element:ty => $constructor:ident),* $(,)?]) => {
        $(
            impl private::ImageCreate<$layout> for $element {
                fn create(size: Size) -> Result<Image<Self, $layout>> {
                    Image::<Self, $layout>::$constructor(size)
                }
            }
        )*
    };
}
67
68impl<T, L> Image<T, L>
69where
70 L: ChannelLayout,
71{
72 pub fn create(size: Size) -> Result<Self>
73 where
74 T: SupportedImage<L>,
75 {
76 <T as private::ImageCreate<L>>::create(size)
77 }
78
79 fn create_u8_c1(size: Size) -> Result<Image<u8, C1>> {
80 Image::create_with(size, sys::nppiMalloc_8u_C1)
81 }
82
83 fn create_u8_c2(size: Size) -> Result<Image<u8, C2>> {
84 Image::create_with(size, sys::nppiMalloc_8u_C2)
85 }
86
87 fn create_u8_c3(size: Size) -> Result<Image<u8, C3>> {
88 Image::create_with(size, sys::nppiMalloc_8u_C3)
89 }
90
91 fn create_u8_c4(size: Size) -> Result<Image<u8, C4>> {
92 Image::create_with(size, sys::nppiMalloc_8u_C4)
93 }
94
95 fn create_u8_ac4(size: Size) -> Result<Image<u8, AC4>> {
96 Image::create_with(size, sys::nppiMalloc_8u_C4)
97 }
98
99 fn create_i8_c1(size: Size) -> Result<Image<i8, C1>> {
100 Image::<i8, C1>::create_signed_8(size, sys::nppiMalloc_8u_C1)
101 }
102
103 fn create_i8_c2(size: Size) -> Result<Image<i8, C2>> {
104 Image::<i8, C2>::create_signed_8(size, sys::nppiMalloc_8u_C2)
105 }
106
107 fn create_i8_c3(size: Size) -> Result<Image<i8, C3>> {
108 Image::<i8, C3>::create_signed_8(size, sys::nppiMalloc_8u_C3)
109 }
110
111 fn create_i8_c4(size: Size) -> Result<Image<i8, C4>> {
112 Image::<i8, C4>::create_signed_8(size, sys::nppiMalloc_8u_C4)
113 }
114
115 fn create_i8_ac4(size: Size) -> Result<Image<i8, AC4>> {
116 Image::<i8, AC4>::create_signed_8(size, sys::nppiMalloc_8u_C4)
117 }
118
119 fn create_u16_c1(size: Size) -> Result<Image<u16, C1>> {
120 Image::create_with(size, sys::nppiMalloc_16u_C1)
121 }
122
123 fn create_u16_c2(size: Size) -> Result<Image<u16, C2>> {
124 Image::create_with(size, sys::nppiMalloc_16u_C2)
125 }
126
127 fn create_u16_c3(size: Size) -> Result<Image<u16, C3>> {
128 Image::create_with(size, sys::nppiMalloc_16u_C3)
129 }
130
131 fn create_u16_c4(size: Size) -> Result<Image<u16, C4>> {
132 Image::create_with(size, sys::nppiMalloc_16u_C4)
133 }
134
135 fn create_u16_ac4(size: Size) -> Result<Image<u16, AC4>> {
136 Image::create_with(size, sys::nppiMalloc_16u_C4)
137 }
138
139 fn create_f16_c1(size: Size) -> Result<Image<f16, C1>> {
140 Image::<f16, C1>::create_f16(size, sys::nppiMalloc_16u_C1)
141 }
142
143 fn create_f16_c2(size: Size) -> Result<Image<f16, C2>> {
144 Image::<f16, C2>::create_f16(size, sys::nppiMalloc_16u_C2)
145 }
146
147 fn create_f16_c3(size: Size) -> Result<Image<f16, C3>> {
148 Image::<f16, C3>::create_f16(size, sys::nppiMalloc_16u_C3)
149 }
150
151 fn create_f16_c4(size: Size) -> Result<Image<f16, C4>> {
152 Image::<f16, C4>::create_f16(size, sys::nppiMalloc_16u_C4)
153 }
154
155 fn create_f16_ac4(size: Size) -> Result<Image<f16, AC4>> {
156 Image::<f16, AC4>::create_f16(size, sys::nppiMalloc_16u_C4)
157 }
158
159 fn create_i16_c1(size: Size) -> Result<Image<i16, C1>> {
160 Image::create_with(size, sys::nppiMalloc_16s_C1)
161 }
162
163 fn create_i16_c2(size: Size) -> Result<Image<i16, C2>> {
164 Image::create_with(size, sys::nppiMalloc_16s_C2)
165 }
166
167 fn create_i16_c3(size: Size) -> Result<Image<i16, C3>> {
168 Image::<i16, C3>::create_signed_16(size, sys::nppiMalloc_16u_C3)
169 }
170
171 fn create_i16_c4(size: Size) -> Result<Image<i16, C4>> {
172 Image::create_with(size, sys::nppiMalloc_16s_C4)
173 }
174
175 fn create_i16_ac4(size: Size) -> Result<Image<i16, AC4>> {
176 Image::create_with(size, sys::nppiMalloc_16s_C4)
177 }
178
179 fn create_i16_complex_c1(size: Size) -> Result<Image<ComplexI16, C1>> {
180 Image::create_with(size, sys::nppiMalloc_16sc_C1)
181 }
182
183 fn create_i16_complex_c2(size: Size) -> Result<Image<ComplexI16, C2>> {
184 Image::create_with(size, sys::nppiMalloc_16sc_C2)
185 }
186
187 fn create_i16_complex_c3(size: Size) -> Result<Image<ComplexI16, C3>> {
188 Image::create_with(size, sys::nppiMalloc_16sc_C3)
189 }
190
191 fn create_i16_complex_c4(size: Size) -> Result<Image<ComplexI16, C4>> {
192 Image::create_with(size, sys::nppiMalloc_16sc_C4)
193 }
194
195 fn create_i16_complex_ac4(size: Size) -> Result<Image<ComplexI16, AC4>> {
196 Image::create_with(size, sys::nppiMalloc_16sc_C4)
197 }
198
199 fn create_i32_c1(size: Size) -> Result<Image<i32, C1>> {
200 Image::create_with(size, sys::nppiMalloc_32s_C1)
201 }
202
203 fn create_u32_c1(size: Size) -> Result<Image<u32, C1>> {
204 size.validate()?;
205 let mut step = 0;
206 let ptr = unsafe { sys::nppiMalloc_32s_C1(size.width, size.height, &raw mut step) };
207 let ptr = NonNull::new(ptr.cast()).ok_or(Error::NullHandle)?;
208
209 Ok(Image {
210 ptr: ptr.as_ptr(),
211 size,
212 step,
213 _t: PhantomData,
214 })
215 }
216
217 fn create_i32_c3(size: Size) -> Result<Image<i32, C3>> {
218 Image::create_with(size, sys::nppiMalloc_32s_C3)
219 }
220
221 fn create_i32_c4(size: Size) -> Result<Image<i32, C4>> {
222 Image::create_with(size, sys::nppiMalloc_32s_C4)
223 }
224
225 fn create_i32_ac4(size: Size) -> Result<Image<i32, AC4>> {
226 Image::create_with(size, sys::nppiMalloc_32s_C4)
227 }
228
229 fn create_u32_ac4(size: Size) -> Result<Image<u32, AC4>> {
230 size.validate()?;
231 let mut step = 0;
232 let ptr = unsafe { sys::nppiMalloc_32s_C4(size.width, size.height, &raw mut step) };
233 let ptr = NonNull::new(ptr.cast()).ok_or(Error::NullHandle)?;
234
235 Ok(Image {
236 ptr: ptr.as_ptr(),
237 size,
238 step,
239 _t: PhantomData,
240 })
241 }
242
243 fn create_i32_complex_c1(size: Size) -> Result<Image<ComplexI32, C1>> {
244 Image::create_with(size, sys::nppiMalloc_32sc_C1)
245 }
246
247 fn create_i32_complex_c2(size: Size) -> Result<Image<ComplexI32, C2>> {
248 Image::create_with(size, sys::nppiMalloc_32sc_C2)
249 }
250
251 fn create_i32_complex_c3(size: Size) -> Result<Image<ComplexI32, C3>> {
252 Image::create_with(size, sys::nppiMalloc_32sc_C3)
253 }
254
255 fn create_i32_complex_c4(size: Size) -> Result<Image<ComplexI32, C4>> {
256 Image::create_with(size, sys::nppiMalloc_32sc_C4)
257 }
258
259 fn create_i32_complex_ac4(size: Size) -> Result<Image<ComplexI32, AC4>> {
260 Image::create_with(size, sys::nppiMalloc_32sc_C4)
261 }
262
263 fn create_f32_c1(size: Size) -> Result<Image<f32, C1>> {
264 Image::create_with(size, sys::nppiMalloc_32f_C1)
265 }
266
267 fn create_f32_c2(size: Size) -> Result<Image<f32, C2>> {
268 Image::create_with(size, sys::nppiMalloc_32f_C2)
269 }
270
271 fn create_f32_c3(size: Size) -> Result<Image<f32, C3>> {
272 Image::create_with(size, sys::nppiMalloc_32f_C3)
273 }
274
275 fn create_f32_c4(size: Size) -> Result<Image<f32, C4>> {
276 Image::create_with(size, sys::nppiMalloc_32f_C4)
277 }
278
279 fn create_f32_ac4(size: Size) -> Result<Image<f32, AC4>> {
280 Image::create_with(size, sys::nppiMalloc_32f_C4)
281 }
282
283 fn create_f32_complex_c1(size: Size) -> Result<Image<Complex32, C1>> {
284 Image::create_with(size, sys::nppiMalloc_32fc_C1)
285 }
286
287 fn create_f32_complex_c2(size: Size) -> Result<Image<Complex32, C2>> {
288 Image::create_with(size, sys::nppiMalloc_32fc_C2)
289 }
290
291 fn create_f32_complex_c3(size: Size) -> Result<Image<Complex32, C3>> {
292 Image::create_with(size, sys::nppiMalloc_32fc_C3)
293 }
294
295 fn create_f32_complex_c4(size: Size) -> Result<Image<Complex32, C4>> {
296 Image::create_with(size, sys::nppiMalloc_32fc_C4)
297 }
298
299 fn create_f32_complex_ac4(size: Size) -> Result<Image<Complex32, AC4>> {
300 Image::create_with(size, sys::nppiMalloc_32fc_C4)
301 }
302
303 fn create_with<S>(
304 size: Size,
305 malloc: unsafe extern "C" fn(i32, i32, *mut i32) -> *mut S,
306 ) -> Result<Self> {
307 size.validate()?;
308 let mut step = 0;
309 let ptr = unsafe { malloc(size.width, size.height, &raw mut step) };
310 let ptr = NonNull::new(ptr.cast()).ok_or(Error::NullHandle)?;
311
312 Ok(Self {
313 ptr: ptr.as_ptr(),
314 size,
315 step,
316 _t: PhantomData,
317 })
318 }
319
320 fn create_signed_8(
321 size: Size,
322 malloc: unsafe extern "C" fn(i32, i32, *mut i32) -> *mut u8,
323 ) -> Result<Image<i8, L>> {
324 size.validate()?;
325 let mut step = 0;
326 let ptr = unsafe { malloc(size.width, size.height, &raw mut step) };
327 let ptr = NonNull::new(ptr.cast()).ok_or(Error::NullHandle)?;
328
329 Ok(Image {
330 ptr: ptr.as_ptr(),
331 size,
332 step,
333 _t: PhantomData,
334 })
335 }
336
337 fn create_signed_16(
338 size: Size,
339 malloc: unsafe extern "C" fn(i32, i32, *mut i32) -> *mut u16,
340 ) -> Result<Image<i16, L>> {
341 size.validate()?;
342 let mut step = 0;
343 let ptr = unsafe { malloc(size.width, size.height, &raw mut step) };
344 let ptr = NonNull::new(ptr.cast()).ok_or(Error::NullHandle)?;
345
346 Ok(Image {
347 ptr: ptr.as_ptr(),
348 size,
349 step,
350 _t: PhantomData,
351 })
352 }
353
354 fn create_f16(
355 size: Size,
356 malloc: unsafe extern "C" fn(i32, i32, *mut i32) -> *mut u16,
357 ) -> Result<Image<f16, L>> {
358 size.validate()?;
359 let mut step = 0;
360 let ptr = unsafe { malloc(size.width, size.height, &raw mut step) };
361 let ptr = NonNull::new(ptr.cast()).ok_or(Error::NullHandle)?;
362
363 Ok(Image {
364 ptr: ptr.as_ptr(),
365 size,
366 step,
367 _t: PhantomData,
368 })
369 }
370
371 pub fn copy_to_device_memory(&self) -> Result<DeviceMemory<T>> {
372 let len = checked_len(self.size, L::CHANNELS)?;
373 let mut destination = DeviceMemory::create(len)?;
374 self.copy_into_device_memory(&mut destination)?;
375 Ok(destination)
376 }
377
378 pub fn copy_to_host_vec(&self) -> Result<Vec<T>> {
379 let len = checked_len(self.size, L::CHANNELS)?;
380 if len == 0 {
381 return Ok(Vec::new());
382 }
383
384 let mut host = Vec::<T>::with_capacity(len);
385 unsafe {
386 self.copy_rows_to(
387 host.as_mut_ptr().cast(),
388 row_bytes::<T, L>(self.size)?,
389 MemoryCopyKind::DeviceToHost,
390 )?;
391 host.set_len(len);
392 }
393 Ok(host)
394 }
395
396 pub fn copy_into_device_memory(&self, destination: &mut DeviceMemory<T>) -> Result<()> {
397 let len = checked_len(self.size, L::CHANNELS)?;
398 if destination.len() != len {
399 return Err(Error::LengthMismatch {
400 name: "image memory".into(),
401 expected: len,
402 actual: destination.len(),
403 });
404 }
405 if len == 0 {
406 return Ok(());
407 }
408
409 unsafe {
410 self.copy_rows_to(
411 destination.as_mut_ptr().cast(),
412 row_bytes::<T, L>(self.size)?,
413 MemoryCopyKind::DeviceToDevice,
414 )?;
415 }
416 Ok(())
417 }
418
    /// Returns the image dimensions the buffer was allocated with.
    pub const fn size(&self) -> Size {
        self.size
    }
423
    /// Returns the row step of the allocation: the byte distance between the starts of
    /// consecutive rows, as written by the allocator.
    pub const fn step(&self) -> i32 {
        self.step
    }
428
429 pub fn view(&self) -> Result<ImageView<'_, T, L>> {
436 unsafe { ImageView::from_raw_parts(self.ptr.cast(), self.size, self.step) }
437 }
438
439 pub fn view_mut(&mut self) -> Result<ImageViewMut<'_, T, L>> {
446 unsafe { ImageViewMut::from_raw_parts(self.ptr.cast(), self.size, self.step) }
447 }
448
449 pub const fn as_ptr(&self) -> *const T {
454 self.ptr as _
455 }
456
    /// Returns the base pointer of the allocation as a `*mut T`.
    pub const fn as_mut_ptr(&mut self) -> *mut T {
        self.ptr
    }
464
465 unsafe fn copy_rows_to(
466 &self,
467 destination: *mut u8,
468 row_bytes: usize,
469 kind: MemoryCopyKind,
470 ) -> Result<()> {
471 let height = self.size.height as usize;
472 let source_step = self.step as usize;
473
474 if source_step == row_bytes {
475 unsafe {
476 DeviceMemory::<u8>::copy(
477 destination,
478 self.ptr.cast(),
479 row_bytes
480 .checked_mul(height)
481 .ok_or_else(|| Error::OutOfRange { name: "len".into() })?,
482 kind,
483 )?;
484 }
485 return Ok(());
486 }
487
488 for row in 0..height {
489 unsafe {
490 DeviceMemory::<u8>::copy(
491 destination.add(
492 row.checked_mul(row_bytes)
493 .ok_or_else(|| Error::OutOfRange { name: "len".into() })?,
494 ),
495 self.ptr.cast::<u8>().add(
496 row.checked_mul(source_step)
497 .ok_or_else(|| Error::OutOfRange { name: "len".into() })?,
498 ),
499 row_bytes,
500 kind,
501 )?;
502 }
503 }
504
505 Ok(())
506 }
507}
508
// Registers the supported (element type, layout) pairs. Each entry maps an element type to
// the private `Image::create_*` helper that allocates it for the given layout.
impl_image_create!(C1, [
    u8 => create_u8_c1,
    i8 => create_i8_c1,
    u16 => create_u16_c1,
    f16 => create_f16_c1,
    i16 => create_i16_c1,
    ComplexI16 => create_i16_complex_c1,
    i32 => create_i32_c1,
    u32 => create_u32_c1,
    ComplexI32 => create_i32_complex_c1,
    f32 => create_f32_c1,
    Complex32 => create_f32_complex_c1,
]);
// `C2` has no `i32` or `u32` entry: this file defines no `create_i32_c2` / `create_u32_c2`.
impl_image_create!(C2, [
    u8 => create_u8_c2,
    i8 => create_i8_c2,
    u16 => create_u16_c2,
    f16 => create_f16_c2,
    i16 => create_i16_c2,
    ComplexI16 => create_i16_complex_c2,
    ComplexI32 => create_i32_complex_c2,
    f32 => create_f32_c2,
    Complex32 => create_f32_complex_c2,
]);
// `C3` has no `u32` entry: this file defines no `create_u32_c3`.
impl_image_create!(C3, [
    u8 => create_u8_c3,
    i8 => create_i8_c3,
    u16 => create_u16_c3,
    f16 => create_f16_c3,
    i16 => create_i16_c3,
    ComplexI16 => create_i16_complex_c3,
    i32 => create_i32_c3,
    ComplexI32 => create_i32_complex_c3,
    f32 => create_f32_c3,
    Complex32 => create_f32_complex_c3,
]);
// `C4` has no `u32` entry: this file defines no `create_u32_c4`.
impl_image_create!(C4, [
    u8 => create_u8_c4,
    i8 => create_i8_c4,
    u16 => create_u16_c4,
    f16 => create_f16_c4,
    i16 => create_i16_c4,
    ComplexI16 => create_i16_complex_c4,
    i32 => create_i32_c4,
    ComplexI32 => create_i32_complex_c4,
    f32 => create_f32_c4,
    Complex32 => create_f32_complex_c4,
]);
// Every `AC4` helper allocates through the corresponding `C4` allocator.
impl_image_create!(AC4, [
    u8 => create_u8_ac4,
    i8 => create_i8_ac4,
    u16 => create_u16_ac4,
    f16 => create_f16_ac4,
    i16 => create_i16_ac4,
    ComplexI16 => create_i16_complex_ac4,
    i32 => create_i32_ac4,
    u32 => create_u32_ac4,
    ComplexI32 => create_i32_complex_ac4,
    f32 => create_f32_ac4,
    Complex32 => create_f32_complex_ac4,
]);
570
571pub fn create<T, L>(size: Size) -> Result<Image<T, L>>
572where
573 T: SupportedImage<L>,
574 L: ChannelLayout,
575{
576 Image::<T, L>::create(size)
577}
578
579impl<T, L> Drop for Image<T, L> {
580 fn drop(&mut self) {
581 unsafe {
582 if !self.ptr.is_null() {
583 sys::nppiFree(self.ptr.cast());
584 }
585 }
586 }
587}
588
589fn row_bytes<T, L>(size: Size) -> Result<usize>
590where
591 L: ChannelLayout,
592{
593 let element_size = size_of::<T>();
594 if element_size == 0 {
595 return Err(Error::OutOfRange {
596 name: "element size".into(),
597 });
598 }
599
600 (size.width as usize)
601 .checked_mul(L::CHANNELS)
602 .and_then(|value| value.checked_mul(element_size))
603 .ok_or_else(|| Error::OutOfRange { name: "len".into() })
604}