tensor_compute/tensors/gpu_tensor/
traits.rs

1use crate::gpu_internals::gpu_buffers::GpuBuffer;
2use crate::gpu_internals::shader_runner::{BufferType, ShaderInput};
3use crate::gpu_internals::GpuInstance;
4use crate::{CpuTensor, ShapeStrideTrait};
5use async_trait::async_trait;
6use zerocopy::AsBytes;
7
8#[async_trait(?Send)]
9pub trait GpuAllocated {
10    fn get_gpu(&self) -> &'static GpuInstance;
11    fn internal_gpu_buffer(&self) -> &GpuBuffer;
12    fn internal_buffer_size_in_bytes(&self) -> usize {
13        self.internal_gpu_buffer().size_bytes()
14    }
15}
16
17#[async_trait(?Send)]
18pub trait CpuTransferable {
19    async fn to_cpu(&self) -> CpuTensor;
20}
21
22#[async_trait(?Send)]
23impl<T> CpuTransferable for T
24where
25    T: GpuAllocated + ShapeStrideTrait,
26{
27    async fn to_cpu(&self) -> CpuTensor {
28        let gpu = self.get_gpu();
29        let buffer_in_cpu_mem = gpu.copy_buffer_to_cpu_mem(self.internal_gpu_buffer()).await;
30        CpuTensor::new_with_strides_and_offset(
31            buffer_in_cpu_mem,
32            self.shape().clone(),
33            self.strides().clone(),
34            self.offset(),
35        )
36    }
37}
38
39pub trait AsShaderInput: GpuAllocated + ShapeStrideTrait {
40    fn to_shader_inputs(&self, binding_offset: usize) -> Vec<ShaderInput> {
41        let mut shape: Vec<u128> = self.shape().iter().map(|&e| e as u128).collect();
42        let mut strides: Vec<u128> = self.strides().iter().map(|&e| e as u128).collect();
43        // Uniform Buffer elements need to be 128bits each:
44        // see https://www.khronos.org/registry/OpenGL/specs/gl/glspec46.core.pdf page 146 (pdf page 168)
45        assert!(shape.len() <= 20, "Shape cant have more than 20 elements");
46        assert!(
47            strides.len() <= 20,
48            "Strides cant have more than 20 elements"
49        );
50        while shape.len() < 20 {
51            shape.push(0);
52        }
53        while strides.len() < 20 {
54            strides.push(0);
55        }
56        let shape_strides_len = self.shape().len() as u32;
57        let offset = self.offset() as u32;
58        let shape_as_uniform = self
59            .get_gpu()
60            .new_uniform_buffer(shape.as_slice().as_bytes());
61        let strides_as_uniform = self
62            .get_gpu()
63            .new_uniform_buffer(strides.as_slice().as_bytes());
64        let shape_strides_len_as_uniform = self
65            .get_gpu()
66            .new_uniform_buffer(shape_strides_len.as_bytes());
67        let offset_as_uniform = self.get_gpu().new_uniform_buffer(offset.as_bytes());
68        vec![
69            ShaderInput {
70                binding_id: binding_offset,
71                gpu_buffer: BufferType::Storage(self.internal_gpu_buffer()),
72            },
73            ShaderInput {
74                binding_id: binding_offset + 1,
75                gpu_buffer: BufferType::UniformOwned(shape_as_uniform),
76            },
77            ShaderInput {
78                binding_id: binding_offset + 2,
79                gpu_buffer: BufferType::UniformOwned(strides_as_uniform),
80            },
81            ShaderInput {
82                binding_id: binding_offset + 3,
83                gpu_buffer: BufferType::UniformOwned(shape_strides_len_as_uniform),
84            },
85            ShaderInput {
86                binding_id: binding_offset + 4,
87                gpu_buffer: BufferType::UniformOwned(offset_as_uniform),
88            },
89        ]
90    }
91}