Skip to main content

sp1_gpu_cudart/tensor/
transpose.rs

1use sp1_gpu_sys::{
2    runtime::{Dim3, KernelPtr},
3    transpose::{
4        transpose_kernel_koala_bear, transpose_kernel_koala_bear_digest,
5        transpose_kernel_koala_bear_extension, transpose_kernel_u32, transpose_kernel_u32_digest,
6    },
7};
8use sp1_primitives::{SP1ExtensionField, SP1Field};
// TransposeBackend was removed; transposition is now provided by the `DeviceTensor::transpose` and `DeviceTensor::transpose_into` methods.
10
11use crate::{args, DeviceCopy, DeviceTensor, TaskScope};
12
13/// # Safety
14pub unsafe trait DeviceTransposeKernel<T> {
15    fn transpose_kernel() -> KernelPtr;
16}
17
18impl<T: DeviceCopy> DeviceTensor<T>
19where
20    TaskScope: DeviceTransposeKernel<T>,
21{
22    /// Transposes the tensor into the given destination tensor.
23    pub fn transpose_into(&self, dst: &mut DeviceTensor<T>) {
24        let src = &self.raw;
25        let mut dst_view = dst.raw.as_view_mut();
26        let num_dims = src.sizes().len();
27
28        let dim_x = src.sizes()[num_dims - 2];
29        let dim_y = src.sizes()[num_dims - 1];
30        let dim_z: usize = src.sizes().iter().take(num_dims - 2).product();
31        assert_eq!(dim_x, dst_view.sizes()[num_dims - 1]);
32        assert_eq!(dim_y, dst_view.sizes()[num_dims - 2]);
33
34        let block_dim: Dim3 = (32u32, 32u32, 1u32).into();
35        let grid_dim: Dim3 = (
36            dim_x.div_ceil(block_dim.x as usize),
37            dim_y.div_ceil(block_dim.y as usize),
38            dim_z.div_ceil(block_dim.z as usize),
39        )
40            .into();
41        let args = args!(src.as_ptr(), dst_view.as_mut_ptr(), dim_x, dim_y, dim_z);
42        unsafe {
43            src.backend()
44                .launch_kernel(TaskScope::transpose_kernel(), grid_dim, block_dim, &args, 0)
45                .unwrap();
46        }
47    }
48
49    /// Transposes the tensor and returns a new tensor.
50    pub fn transpose(&self) -> DeviceTensor<T> {
51        let src = &self.raw;
52        let num_dims = src.sizes().len();
53        let mut transposed_sizes = src.sizes().to_vec();
54        transposed_sizes.swap(num_dims - 2, num_dims - 1);
55        let mut dst = DeviceTensor::with_sizes_in(&transposed_sizes, src.backend().clone());
56        unsafe {
57            dst.assume_init();
58        }
59        self.transpose_into(&mut dst);
60        dst
61    }
62}
63
64unsafe impl DeviceTransposeKernel<u32> for TaskScope {
65    fn transpose_kernel() -> KernelPtr {
66        unsafe { transpose_kernel_u32() }
67    }
68}
69
70unsafe impl DeviceTransposeKernel<[u32; 8]> for TaskScope {
71    fn transpose_kernel() -> KernelPtr {
72        unsafe { transpose_kernel_u32_digest() }
73    }
74}
75
76unsafe impl DeviceTransposeKernel<SP1Field> for TaskScope {
77    fn transpose_kernel() -> KernelPtr {
78        unsafe { transpose_kernel_koala_bear() }
79    }
80}
81
82unsafe impl DeviceTransposeKernel<SP1ExtensionField> for TaskScope {
83    fn transpose_kernel() -> KernelPtr {
84        unsafe { transpose_kernel_koala_bear_extension() }
85    }
86}
87
88unsafe impl DeviceTransposeKernel<[SP1Field; 8]> for TaskScope {
89    fn transpose_kernel() -> KernelPtr {
90        unsafe { transpose_kernel_koala_bear_digest() }
91    }
92}
93
#[cfg(test)]
mod tests {
    use slop_tensor::Tensor;

    use super::*;

    #[test]
    fn test_tensor_transpose() {
        let mut rng = rand::thread_rng();

        // Square, tall and wide shapes, shapes that are not multiples of the 32x32 launch
        // block, and one very tall matrix that spans many blocks along a single grid axis.
        for (height, width) in [
            (1, 1),
            (33, 65),
            (1024, 1024),
            (1024, 6),
            (6, 1024),
            (1024, 2048),
            (2048, 1024),
            (2048, 2048),
            (1 << 22, 100),
        ] {
            let tensor = Tensor::<u32>::rand(&mut rng, [height, width]);
            let transposed_expected = tensor.transpose();
            let transposed = crate::run_sync_in_place(|t| {
                let device_tensor = DeviceTensor::from_host(&tensor, &t).unwrap();
                let transposed = device_tensor.transpose();
                transposed.to_host().unwrap()
            })
            .unwrap();

            // Compare shapes first: `zip` below stops at the shorter buffer, so a size
            // mismatch would otherwise go unnoticed.
            assert_eq!(transposed.sizes(), transposed_expected.sizes());
            for (i, (val, expected)) in transposed
                .as_buffer()
                .iter()
                .zip(transposed_expected.as_buffer().iter())
                .enumerate()
            {
                assert_eq!(
                    val, expected,
                    "mismatch at flat index {i} for shape ({height}, {width})"
                );
            }
        }
    }
}