Skip to main content

webgpu_groth16/gpu/
buffers.rs

1//! GPU buffer read-back operations.
2//!
3//! Handles asynchronous GPU → CPU data transfer via staging buffers.
4
5use futures::channel::oneshot;
6
7use super::GpuContext;
8use super::curve::GpuCurve;
9
10impl<C: GpuCurve> GpuContext<C> {
11    pub async fn read_buffer(
12        &self,
13        buffer: &wgpu::Buffer,
14        size: wgpu::BufferAddress,
15    ) -> anyhow::Result<Vec<u8>> {
16        let staging_buffer =
17            self.device.create_buffer(&wgpu::BufferDescriptor {
18                label: Some("Staging Read Buffer"),
19                size,
20                usage: wgpu::BufferUsages::MAP_READ
21                    | wgpu::BufferUsages::COPY_DST,
22                mapped_at_creation: false,
23            });
24
25        let mut encoder = self.device.create_command_encoder(
26            &wgpu::CommandEncoderDescriptor { label: None },
27        );
28        encoder.copy_buffer_to_buffer(buffer, 0, &staging_buffer, 0, size);
29        self.queue.submit(Some(encoder.finish()));
30
31        let buffer_slice = staging_buffer.slice(..);
32        let (sender, receiver) = oneshot::channel();
33        buffer_slice.map_async(wgpu::MapMode::Read, move |res| {
34            sender.send(res).unwrap();
35        });
36
37        #[cfg(not(target_family = "wasm"))]
38        let _ = self.device.poll(wgpu::PollType::wait_indefinitely());
39
40        if let Ok(Ok(())) = receiver.await {
41            let data = buffer_slice.get_mapped_range().to_vec();
42            _ = buffer_slice;
43            staging_buffer.unmap();
44            return Ok(data);
45        }
46        anyhow::bail!("Failed to read back from GPU buffer")
47    }
48
49    /// Reads multiple GPU buffers in a single command submission for
50    /// efficiency.
51    ///
52    /// All copy commands are batched into one encoder, submitted together, and
53    /// then all staging buffers are mapped concurrently. This avoids the
54    /// overhead of per-buffer submission and device polling.
55    pub async fn read_buffers_batch(
56        &self,
57        entries: &[(&wgpu::Buffer, wgpu::BufferAddress)],
58    ) -> anyhow::Result<Vec<Vec<u8>>> {
59        let mut staging = Vec::with_capacity(entries.len());
60        for (_, size) in entries {
61            staging.push(self.device.create_buffer(&wgpu::BufferDescriptor {
62                label: Some("Batch Staging Read Buffer"),
63                size: *size,
64                usage: wgpu::BufferUsages::MAP_READ
65                    | wgpu::BufferUsages::COPY_DST,
66                mapped_at_creation: false,
67            }));
68        }
69
70        let mut encoder = self.device.create_command_encoder(
71            &wgpu::CommandEncoderDescriptor {
72                label: Some("Batch Read Encoder"),
73            },
74        );
75        for (i, (src, size)) in entries.iter().enumerate() {
76            encoder.copy_buffer_to_buffer(src, 0, &staging[i], 0, *size);
77        }
78        self.queue.submit(Some(encoder.finish()));
79
80        let mut receivers = Vec::with_capacity(staging.len());
81        for s in &staging {
82            let slice = s.slice(..);
83            let (sender, receiver) = oneshot::channel();
84            slice.map_async(wgpu::MapMode::Read, move |res| {
85                let _ = sender.send(res);
86            });
87            receivers.push(receiver);
88        }
89
90        #[cfg(not(target_family = "wasm"))]
91        let _ = self.device.poll(wgpu::PollType::wait_indefinitely());
92
93        for r in receivers {
94            match r.await {
95                Ok(Ok(())) => {}
96                _ => anyhow::bail!("Failed to map one of batch read buffers"),
97            }
98        }
99
100        let mut out = Vec::with_capacity(staging.len());
101        for s in staging {
102            let bytes = s.slice(..).get_mapped_range().to_vec();
103            s.unmap();
104            out.push(bytes);
105        }
106        Ok(out)
107    }
108}