// wgpu_burn_interop/interop.rs
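//! Interop helpers that copy burn tensors (CubeCL wgpu backend) into
//! standalone wgpu buffers, entirely on the GPU. Assumes burn and wgpu were
//! initialised on the same device and queue (see crate wgpu_burn_global_device).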

use burn::tensor::{Int, Tensor, TensorMetadata};

use burn_cubecl::tensor::CubeTensor;
use cubecl::wgpu::WgpuRuntime;
use gloss_burn_multibackend::{
    backend::MultiBackend,
    tensor::{MultiFloatTensor, MultiIntTensor},
};

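/// Copies a 2D float tensor into a freshly created wgpu buffer with the
/// requested usages (plus `COPY_DST`). Panics if the tensor does not live on
/// the wgpu backend of `MultiBackend`.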
pub fn tensor_float2wgpu_buffer(
    tensor: Tensor<MultiBackend, 2>,
    usages: wgpu::BufferUsages,
    device: &wgpu::Device,
    queue: &wgpu::Queue,
) -> wgpu::Buffer {
    // Get the underlying cube tensor (float primitives are wrapped in
    // `TensorPrimitive`, hence the extra `.tensor()` call).
    let cube_tensor = tensor.into_primitive().tensor();
    let MultiFloatTensor::Wgpu(cube_tensor) = cube_tensor else {
        panic!("Expected wgpu tensor, got tensor with dtype {:?}", cube_tensor.dtype())
    };

    cubewgpu_tensor2wgpu_buffer(cube_tensor, usages, device, queue)
}

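/// Same as [`tensor_float2wgpu_buffer`], but for 2D integer tensors.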
pub fn tensor_int2wgpu_buffer(
    tensor: Tensor<MultiBackend, 2, Int>,
    usages: wgpu::BufferUsages,
    device: &wgpu::Device,
    queue: &wgpu::Queue,
) -> wgpu::Buffer {
    // Get the underlying cube tensor (int primitives are not wrapped, so no
    // `.tensor()` call here).
    let cube_tensor = tensor.into_primitive();
    let MultiIntTensor::Wgpu(cube_tensor) = cube_tensor else {
        panic!("Expected wgpu tensor, got tensor with dtype {:?}", cube_tensor.dtype())
    };

    cubewgpu_tensor2wgpu_buffer(cube_tensor, usages, device, queue)
}

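/// Copies the storage behind a `CubeTensor` on the wgpu runtime into a new
/// wgpu buffer. The copy is encoded on `queue`, so it is ordered after the
/// flushed burn work as long as burn uses the same queue.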
fn cubewgpu_tensor2wgpu_buffer(
    tensor: CubeTensor<WgpuRuntime>,
    usages: wgpu::BufferUsages,
    device: &wgpu::Device,
    queue: &wgpu::Queue,
) -> wgpu::Buffer {
    // Get the 'resource' from the client.
    let client = tensor.client;
    let binding = client.get_resource(tensor.handle.clone().binding());
    let resource = binding.resource();

    // The resource holds the wgpu buffer.
    let buffer = resource.buffer();

    // Note that the tensor may only use a slice of that buffer, described by
    // offset + size.
    let offset = resource.offset();
    let size = resource.size();

    // The client buffers pending work, so flush first to make sure it is all
    // queued. No need to sync, since the work lands on the same queue as our
    // wgpu submissions (assuming you are using the same device between burn
    // and wgpu; see crate wgpu_burn_global_device).
    client.flush();
    // client.sync().block_on();

    // Create the destination buffer.
    let dst_buffer = device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("tensor2wgpu_buffer_dst"),
        size,
        usage: wgpu::BufferUsages::COPY_DST | usages,
        mapped_at_creation: false,
    });

    // Encode the copy from the tensor's slice into the new buffer.
    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
        label: Some("tensor2wgpu_buffer_copy_encoder"),
    });

    encoder.copy_buffer_to_buffer(buffer, offset, &dst_buffer, 0, size);

    // Submit the copy.
    queue.submit(Some(encoder.finish()));

    dst_buffer
}
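
// A minimal usage sketch (not part of the original file): copies a small
// float tensor into a wgpu storage buffer. It assumes burn and wgpu were
// initialised on the same device/queue (see crate wgpu_burn_global_device);
// the function name and tensor shape are illustrative only.
#[allow(dead_code)]
fn example_tensor_to_storage_buffer(
    burn_device: &<MultiBackend as burn::tensor::backend::Backend>::Device,
    device: &wgpu::Device,
    queue: &wgpu::Queue,
) -> wgpu::Buffer {
    // A 4x4 float tensor of zeros on the burn side.
    let tensor = Tensor::<MultiBackend, 2>::zeros([4, 4], burn_device);

    // Copy it into a buffer that can be bound as storage in a wgpu pipeline.
    tensor_float2wgpu_buffer(
        tensor,
        wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
        device,
        queue,
    )
}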