webgpu_groth16/gpu/
buffers.rs1use futures::channel::oneshot;
6
7use super::GpuContext;
8use super::curve::GpuCurve;
9
10impl<C: GpuCurve> GpuContext<C> {
11 pub async fn read_buffer(
12 &self,
13 buffer: &wgpu::Buffer,
14 size: wgpu::BufferAddress,
15 ) -> anyhow::Result<Vec<u8>> {
16 let staging_buffer =
17 self.device.create_buffer(&wgpu::BufferDescriptor {
18 label: Some("Staging Read Buffer"),
19 size,
20 usage: wgpu::BufferUsages::MAP_READ
21 | wgpu::BufferUsages::COPY_DST,
22 mapped_at_creation: false,
23 });
24
25 let mut encoder = self.device.create_command_encoder(
26 &wgpu::CommandEncoderDescriptor { label: None },
27 );
28 encoder.copy_buffer_to_buffer(buffer, 0, &staging_buffer, 0, size);
29 self.queue.submit(Some(encoder.finish()));
30
31 let buffer_slice = staging_buffer.slice(..);
32 let (sender, receiver) = oneshot::channel();
33 buffer_slice.map_async(wgpu::MapMode::Read, move |res| {
34 sender.send(res).unwrap();
35 });
36
37 #[cfg(not(target_family = "wasm"))]
38 let _ = self.device.poll(wgpu::PollType::wait_indefinitely());
39
40 if let Ok(Ok(())) = receiver.await {
41 let data = buffer_slice.get_mapped_range().to_vec();
42 _ = buffer_slice;
43 staging_buffer.unmap();
44 return Ok(data);
45 }
46 anyhow::bail!("Failed to read back from GPU buffer")
47 }
48
49 pub async fn read_buffers_batch(
56 &self,
57 entries: &[(&wgpu::Buffer, wgpu::BufferAddress)],
58 ) -> anyhow::Result<Vec<Vec<u8>>> {
59 let mut staging = Vec::with_capacity(entries.len());
60 for (_, size) in entries {
61 staging.push(self.device.create_buffer(&wgpu::BufferDescriptor {
62 label: Some("Batch Staging Read Buffer"),
63 size: *size,
64 usage: wgpu::BufferUsages::MAP_READ
65 | wgpu::BufferUsages::COPY_DST,
66 mapped_at_creation: false,
67 }));
68 }
69
70 let mut encoder = self.device.create_command_encoder(
71 &wgpu::CommandEncoderDescriptor {
72 label: Some("Batch Read Encoder"),
73 },
74 );
75 for (i, (src, size)) in entries.iter().enumerate() {
76 encoder.copy_buffer_to_buffer(src, 0, &staging[i], 0, *size);
77 }
78 self.queue.submit(Some(encoder.finish()));
79
80 let mut receivers = Vec::with_capacity(staging.len());
81 for s in &staging {
82 let slice = s.slice(..);
83 let (sender, receiver) = oneshot::channel();
84 slice.map_async(wgpu::MapMode::Read, move |res| {
85 let _ = sender.send(res);
86 });
87 receivers.push(receiver);
88 }
89
90 #[cfg(not(target_family = "wasm"))]
91 let _ = self.device.poll(wgpu::PollType::wait_indefinitely());
92
93 for r in receivers {
94 match r.await {
95 Ok(Ok(())) => {}
96 _ => anyhow::bail!("Failed to map one of batch read buffers"),
97 }
98 }
99
100 let mut out = Vec::with_capacity(staging.len());
101 for s in staging {
102 let bytes = s.slice(..).get_mapped_range().to_vec();
103 s.unmap();
104 out.push(bytes);
105 }
106 Ok(out)
107 }
108}