tensor_compute/tensors/gpu_tensor/
traits.rs1use crate::gpu_internals::gpu_buffers::GpuBuffer;
2use crate::gpu_internals::shader_runner::{BufferType, ShaderInput};
3use crate::gpu_internals::GpuInstance;
4use crate::{CpuTensor, ShapeStrideTrait};
5use async_trait::async_trait;
6use zerocopy::AsBytes;
7
8#[async_trait(?Send)]
9pub trait GpuAllocated {
10 fn get_gpu(&self) -> &'static GpuInstance;
11 fn internal_gpu_buffer(&self) -> &GpuBuffer;
12 fn internal_buffer_size_in_bytes(&self) -> usize {
13 self.internal_gpu_buffer().size_bytes()
14 }
15}
16
17#[async_trait(?Send)]
18pub trait CpuTransferable {
19 async fn to_cpu(&self) -> CpuTensor;
20}
21
22#[async_trait(?Send)]
23impl<T> CpuTransferable for T
24where
25 T: GpuAllocated + ShapeStrideTrait,
26{
27 async fn to_cpu(&self) -> CpuTensor {
28 let gpu = self.get_gpu();
29 let buffer_in_cpu_mem = gpu.copy_buffer_to_cpu_mem(self.internal_gpu_buffer()).await;
30 CpuTensor::new_with_strides_and_offset(
31 buffer_in_cpu_mem,
32 self.shape().clone(),
33 self.strides().clone(),
34 self.offset(),
35 )
36 }
37}
38
39pub trait AsShaderInput: GpuAllocated + ShapeStrideTrait {
40 fn to_shader_inputs(&self, binding_offset: usize) -> Vec<ShaderInput> {
41 let mut shape: Vec<u128> = self.shape().iter().map(|&e| e as u128).collect();
42 let mut strides: Vec<u128> = self.strides().iter().map(|&e| e as u128).collect();
43 assert!(shape.len() <= 20, "Shape cant have more than 20 elements");
46 assert!(
47 strides.len() <= 20,
48 "Strides cant have more than 20 elements"
49 );
50 while shape.len() < 20 {
51 shape.push(0);
52 }
53 while strides.len() < 20 {
54 strides.push(0);
55 }
56 let shape_strides_len = self.shape().len() as u32;
57 let offset = self.offset() as u32;
58 let shape_as_uniform = self
59 .get_gpu()
60 .new_uniform_buffer(shape.as_slice().as_bytes());
61 let strides_as_uniform = self
62 .get_gpu()
63 .new_uniform_buffer(strides.as_slice().as_bytes());
64 let shape_strides_len_as_uniform = self
65 .get_gpu()
66 .new_uniform_buffer(shape_strides_len.as_bytes());
67 let offset_as_uniform = self.get_gpu().new_uniform_buffer(offset.as_bytes());
68 vec![
69 ShaderInput {
70 binding_id: binding_offset,
71 gpu_buffer: BufferType::Storage(self.internal_gpu_buffer()),
72 },
73 ShaderInput {
74 binding_id: binding_offset + 1,
75 gpu_buffer: BufferType::UniformOwned(shape_as_uniform),
76 },
77 ShaderInput {
78 binding_id: binding_offset + 2,
79 gpu_buffer: BufferType::UniformOwned(strides_as_uniform),
80 },
81 ShaderInput {
82 binding_id: binding_offset + 3,
83 gpu_buffer: BufferType::UniformOwned(shape_strides_len_as_uniform),
84 },
85 ShaderInput {
86 binding_id: binding_offset + 4,
87 gpu_buffer: BufferType::UniformOwned(offset_as_uniform),
88 },
89 ]
90 }
91}