Skip to main content

singe_npp/
workspace.rs

1//! Scratch-buffer support for NPP wrappers.
2//!
3//! `ScratchBuffer` owns temporary device memory used by NPP operations that
4//! expose host-side buffer-size queries.
5
6use singe_cuda::memory::DeviceMemory;
7use singe_npp_sys as sys;
8
9use crate::{
10    error::{Error, Result},
11    types::BufferDescriptor,
12    utility::to_i32,
13};
14
15#[derive(Debug)]
16pub struct ScratchBuffer {
17    memory: DeviceMemory<u8>,
18}
19
20impl ScratchBuffer {
21    pub fn create(bytes: usize) -> Result<Self> {
22        Ok(Self {
23            memory: DeviceMemory::create(bytes)?,
24        })
25    }
26
27    pub fn from_memory(memory: DeviceMemory<u8>) -> Self {
28        Self { memory }
29    }
30
31    pub const fn len(&self) -> usize {
32        self.memory.len()
33    }
34
35    pub const fn is_empty(&self) -> bool {
36        self.memory.is_empty()
37    }
38
39    pub fn require(&self, bytes: usize) -> Result<()> {
40        if self.len() < bytes {
41            return Err(Error::LengthMismatch {
42                name: "scratch buffer".into(),
43                expected: bytes,
44                actual: self.len(),
45            });
46        }
47
48        Ok(())
49    }
50
51    pub fn as_mut_ptr(&mut self) -> *mut u8 {
52        self.memory.as_mut_ptr()
53    }
54
55    pub fn into_device_memory(self) -> DeviceMemory<u8> {
56        self.memory
57    }
58}
59
60#[derive(Debug)]
61pub(crate) struct BufferDescriptors {
62    memory: DeviceMemory<sys::NppiBufferDescriptor>,
63}
64
65impl BufferDescriptors {
66    pub fn as_mut_ptr(&mut self) -> *mut sys::NppiBufferDescriptor {
67        self.memory.as_mut_ptr()
68    }
69}
70
71pub(crate) fn create_buffer_descriptors(
72    buffer_sizes: impl IntoIterator<Item = usize>,
73) -> Result<(Vec<DeviceMemory<u8>>, BufferDescriptors)> {
74    let buffer_sizes: Vec<_> = buffer_sizes.into_iter().collect();
75    let mut buffers = Vec::with_capacity(buffer_sizes.len());
76    let mut descriptors = Vec::with_capacity(buffer_sizes.len());
77
78    for bytes in buffer_sizes {
79        let descriptor_size = to_i32(bytes, "buffer size")?;
80        let buffer = DeviceMemory::<u8>::create(bytes)?;
81        descriptors.push(BufferDescriptor {
82            data: buffer.as_mut_ptr().cast(),
83            size: descriptor_size,
84        });
85        buffers.push(buffer);
86    }
87
88    let raw_descriptors = descriptors
89        .into_iter()
90        .map(sys::NppiBufferDescriptor::from)
91        .collect::<Vec<_>>();
92
93    Ok((
94        buffers,
95        BufferDescriptors {
96            memory: DeviceMemory::from_slice(&raw_descriptors)?,
97        },
98    ))
99}
100
101pub(crate) fn create_repeated_buffer_descriptors(
102    count: usize,
103    bytes_per_buffer: usize,
104) -> Result<(Vec<DeviceMemory<u8>>, BufferDescriptors)> {
105    create_buffer_descriptors(std::iter::repeat_n(bytes_per_buffer, count))
106}