sp1_gpu_cudart/tensor/
transpose.rs1use sp1_gpu_sys::{
2 runtime::{Dim3, KernelPtr},
3 transpose::{
4 transpose_kernel_koala_bear, transpose_kernel_koala_bear_digest,
5 transpose_kernel_koala_bear_extension, transpose_kernel_u32, transpose_kernel_u32_digest,
6 },
7};
8use sp1_primitives::{SP1ExtensionField, SP1Field};
9use crate::{args, DeviceCopy, DeviceTensor, TaskScope};
12
13pub unsafe trait DeviceTransposeKernel<T> {
15 fn transpose_kernel() -> KernelPtr;
16}
17
18impl<T: DeviceCopy> DeviceTensor<T>
19where
20 TaskScope: DeviceTransposeKernel<T>,
21{
22 pub fn transpose_into(&self, dst: &mut DeviceTensor<T>) {
24 let src = &self.raw;
25 let mut dst_view = dst.raw.as_view_mut();
26 let num_dims = src.sizes().len();
27
28 let dim_x = src.sizes()[num_dims - 2];
29 let dim_y = src.sizes()[num_dims - 1];
30 let dim_z: usize = src.sizes().iter().take(num_dims - 2).product();
31 assert_eq!(dim_x, dst_view.sizes()[num_dims - 1]);
32 assert_eq!(dim_y, dst_view.sizes()[num_dims - 2]);
33
34 let block_dim: Dim3 = (32u32, 32u32, 1u32).into();
35 let grid_dim: Dim3 = (
36 dim_x.div_ceil(block_dim.x as usize),
37 dim_y.div_ceil(block_dim.y as usize),
38 dim_z.div_ceil(block_dim.z as usize),
39 )
40 .into();
41 let args = args!(src.as_ptr(), dst_view.as_mut_ptr(), dim_x, dim_y, dim_z);
42 unsafe {
43 src.backend()
44 .launch_kernel(TaskScope::transpose_kernel(), grid_dim, block_dim, &args, 0)
45 .unwrap();
46 }
47 }
48
49 pub fn transpose(&self) -> DeviceTensor<T> {
51 let src = &self.raw;
52 let num_dims = src.sizes().len();
53 let mut transposed_sizes = src.sizes().to_vec();
54 transposed_sizes.swap(num_dims - 2, num_dims - 1);
55 let mut dst = DeviceTensor::with_sizes_in(&transposed_sizes, src.backend().clone());
56 unsafe {
57 dst.assume_init();
58 }
59 self.transpose_into(&mut dst);
60 dst
61 }
62}
63
64unsafe impl DeviceTransposeKernel<u32> for TaskScope {
65 fn transpose_kernel() -> KernelPtr {
66 unsafe { transpose_kernel_u32() }
67 }
68}
69
70unsafe impl DeviceTransposeKernel<[u32; 8]> for TaskScope {
71 fn transpose_kernel() -> KernelPtr {
72 unsafe { transpose_kernel_u32_digest() }
73 }
74}
75
76unsafe impl DeviceTransposeKernel<SP1Field> for TaskScope {
77 fn transpose_kernel() -> KernelPtr {
78 unsafe { transpose_kernel_koala_bear() }
79 }
80}
81
82unsafe impl DeviceTransposeKernel<SP1ExtensionField> for TaskScope {
83 fn transpose_kernel() -> KernelPtr {
84 unsafe { transpose_kernel_koala_bear_extension() }
85 }
86}
87
88unsafe impl DeviceTransposeKernel<[SP1Field; 8]> for TaskScope {
89 fn transpose_kernel() -> KernelPtr {
90 unsafe { transpose_kernel_koala_bear_digest() }
91 }
92}
93
#[cfg(test)]
mod tests {
    use slop_tensor::Tensor;

    use super::*;

    /// Checks the device transpose against the host `Tensor::transpose` for a range of
    /// square, tall and wide matrices.
    #[test]
    fn test_tensor_transpose() {
        let mut rng = rand::thread_rng();

        for (height, width) in [
            (1024, 1024),
            (1024, 6),
            (6, 1024),
            (1024, 2048),
            (2048, 1024),
            (2048, 2048),
            // Very tall matrix: the height is the extent mapped to the grid's x dimension.
            (1 << 22, 100),
        ] {
            let tensor = Tensor::<u32>::rand(&mut rng, [height, width]);
            let transposed_expected = tensor.transpose();
            let transposed = crate::run_sync_in_place(|t| {
                let device_tensor = DeviceTensor::from_host(&tensor, &t).unwrap();
                let transposed = device_tensor.transpose();
                transposed.to_host().unwrap()
            })
            .unwrap();

            // `zip` below stops at the shorter buffer, so check the shapes first; otherwise a
            // truncated or mis-shaped result would pass silently.
            assert_eq!(
                transposed.sizes(),
                transposed_expected.sizes(),
                "shape mismatch for {height}x{width} input"
            );

            for (i, (val, expected)) in transposed
                .as_buffer()
                .iter()
                .zip(transposed_expected.as_buffer().iter())
                .enumerate()
            {
                assert_eq!(*val, *expected, "mismatch at element {i} for {height}x{width} input");
            }
        }
    }
}