mod bfs;

use self::bfs::bfs_reachability::bfs_reachability;
use crate::bytemuck_safe::{safe_bytes_of_slice, safe_cast_slice};
use crate::runtime::cache::{BufferPool, PooledBuffer};
use bytemuck::{Pod, Zeroable};
use vyre::error::{Error, Result};
use vyre::ops::graph::bfs::Bfs;
use vyre::ops::graph::csr::CsrGraph;

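/// Compute workgroup size used for the graph BFS dispatch.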
pub const WORKGROUP_SIZE: u32 = 64;

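/// Maximum number of `GpuFinding` rows a single BFS dispatch will report.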
pub const MAX_FINDINGS: usize = 1_000_000;

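/// One reachability finding produced on the GPU: the start node, the sink node, the depth at
/// which the sink was reached, and the index of the originating source.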
#[repr(C)]
#[derive(Copy, Clone, Debug, Pod, Zeroable)]
pub(crate) struct GpuFinding {
    start_node: u32,
    sink_node: u32,
    depth: u32,
    source_idx: u32,
}

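/// Dispatch parameters uploaded to the BFS kernel. The trailing `_pad` fields round the struct
/// up to 32 bytes.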
#[repr(C)]
#[derive(Copy, Clone, Debug, Pod, Zeroable)]
pub(crate) struct GpuParams {
    num_sources: u32,
    num_nodes: u32,
    max_findings: u32,
    max_depth: u32,
    words_per_source: u32,
    _pad0: u32,
    _pad1: u32,
    _pad2: u32,
}

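/// Validates the CSR graph and rejects any source node that does not fit in `usize` or falls
/// outside the graph's `node_count`.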
pub(crate) fn validate_inputs(csr: &CsrGraph, sources: &[u32]) -> Result<()> {
    csr.validate()?;
    for &source in sources {
        let source_index = usize::try_from(source).map_err(|err| Error::Dataflow {
            message: format!(
                "source node {source} cannot fit usize: {err}. Fix: reject this source on this platform."
            ),
        })?;
        if source_index >= csr.node_count() {
            return Err(Error::Dataflow {
                message: format!(
                    "source node {source} is outside node_count {}. Fix: remove invalid sources before GPU dispatch.",
                    csr.node_count()
                ),
            });
        }
    }
    Ok(())
}

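/// Converts a `usize` into `u32`, reporting a `Dataflow` error naming `label` when the value
/// exceeds `u32::MAX`.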
pub(crate) fn checked_u32(value: usize, label: &str) -> Result<u32> {
    u32::try_from(value).map_err(|source| Error::Dataflow {
        message: format!("{label} value {value} exceeds u32::MAX: {source}. Fix: split the graph input before GPU dispatch."),
    })
}

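/// Waits for `submission`, maps the four-byte count readback buffer, and returns the finding
/// count reported by the kernel, clamped to `max_findings`.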
pub fn read_finding_count(
    device: &wgpu::Device,
    readback_count_buffer: &wgpu::Buffer,
    max_findings: u32,
    submission: wgpu::SubmissionIndex,
) -> Result<u32> {
    let slice = readback_count_buffer.slice(0..4);
    let (tx, rx) = std::sync::mpsc::channel();
    slice.map_async(wgpu::MapMode::Read, move |result| {
        let _ = tx.send(result);
    });
    match device.poll(wgpu::Maintain::wait_for(submission)) {
        wgpu::MaintainResult::Ok | wgpu::MaintainResult::SubmissionQueueEmpty => {}
    }
    rx.recv()
        .map_err(|error| Error::Dataflow {
            message: format!("GPU finding count map channel closed: {error}. Fix: keep the readback receiver alive until map_async completes."),
        })?
        .map_err(|error| Error::Dataflow {
            message: format!("GPU finding count map failed: {error:?}. Fix: check for device loss, adapter timeout, or invalid readback buffer usage."),
        })?;
    let mapped = slice.get_mapped_range();
    let count = u32::from_ne_bytes(mapped[..4].try_into().map_err(|source| {
        Error::Dataflow {
            message: format!("GPU finding count readback was not four bytes: {source}. Fix: inspect the count readback buffer size."),
        }
    })?)
    .min(max_findings);
    drop(mapped);
    readback_count_buffer.unmap();
    Ok(count)
}

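/// Waits for `submission`, maps the findings readback buffer, and decodes the first `count`
/// rows into `(start_node, sink_node, depth)` triples.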
pub fn read_finding_rows(
    device: &wgpu::Device,
    readback_buffer: &wgpu::Buffer,
    count: u32,
    submission: wgpu::SubmissionIndex,
) -> Result<Vec<(u32, u32, u32)>> {
    let finding_byte_len = u64::from(count)
        * u64::try_from(std::mem::size_of::<GpuFinding>()).map_err(|source| Error::Dataflow {
            message: format!(
                "GpuFinding size cannot fit u64: {source}. Fix: run on a supported target."
            ),
        })?;
    let slice = readback_buffer.slice(0..finding_byte_len);
    let (tx, rx) = std::sync::mpsc::channel();
    slice.map_async(wgpu::MapMode::Read, move |result| {
        let _ = tx.send(result);
    });
    match device.poll(wgpu::Maintain::wait_for(submission)) {
        wgpu::MaintainResult::Ok | wgpu::MaintainResult::SubmissionQueueEmpty => {}
    }
    rx.recv()
        .map_err(|error| Error::Dataflow {
            message: format!("GPU findings map channel closed: {error}. Fix: keep the readback receiver alive until map_async completes."),
        })?
        .map_err(|error| Error::Dataflow {
            message: format!("GPU findings map failed: {error:?}. Fix: check for device loss, adapter timeout, or invalid readback buffer usage."),
        })?;
    let mapped = slice.get_mapped_range();
    let findings: &[GpuFinding] = safe_cast_slice(&mapped).map_err(|error| Error::Dataflow {
        message: format!(
            "safe cast failed: {error}. Fix: ensure the readback buffer size matches GpuFinding layout."
        ),
    })?;
    let results = findings
        .iter()
        .take(usize::try_from(count).map_err(|source| Error::Dataflow {
            message: format!("finding count {count} cannot fit usize: {source}. Fix: reject this readback on this platform."),
        })?)
        .map(|finding| (finding.start_node, finding.sink_node, finding.depth))
        .collect();
    drop(mapped);
    readback_buffer.unmap();
    Ok(results)
}

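/// Reads the finding count and then the matching rows; returns an empty vector when the kernel
/// reported no findings.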
pub(crate) fn read_findings(
    device: &wgpu::Device,
    readback_buffer: &wgpu::Buffer,
    readback_count_buffer: &wgpu::Buffer,
    max_findings: u32,
    submission: wgpu::SubmissionIndex,
) -> Result<Vec<(u32, u32, u32)>> {
    let count = read_finding_count(
        device,
        readback_count_buffer,
        max_findings,
        submission.clone(),
    )?;
    if count == 0 {
        return Ok(Vec::new());
    }
    read_finding_rows(device, readback_buffer, count, submission)
}

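/// Verifies the BFS operation certificate, lowers the canonical BFS program (sized for
/// `queue_slots` queue entries) to WGSL, and compiles it into a compute pipeline.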
pub(crate) fn create_pipeline(
    device: &wgpu::Device,
    queue_slots: u32,
) -> Result<wgpu::ComputePipeline> {
    vyre::ops::registry::gate::verify_certificate(Bfs::SPEC.id()).map_err(|source| {
        Error::Dataflow {
            message: source.to_string(),
        }
    })?;
    let source = vyre::lower::wgsl::lower(&Bfs::program_with_queue_size(queue_slots)).map_err(
        |source| Error::Dataflow {
            message: format!("failed to lower graph BFS IR to WGSL: {source}. Fix: validate the canonical BFS Program and lowerer before dispatch."),
        },
    )?;
    crate::runtime::compile_compute_pipeline(device, "graph bfs pipeline", &source, "main")
        .map_err(|source| Error::Dataflow {
            message: format!(
                "failed to compile graph BFS pipeline: {source}. Fix: repair the BFS WGSL or GPU runtime configuration."
            ),
        })
}

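/// Acquires a pooled buffer with `COPY_DST` added to `usage` and uploads `data`. Empty slices
/// are backed by sixteen zero bytes so the resulting buffer is never zero-sized.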
pub(crate) fn create_buffer<T: Pod>(
    device: &wgpu::Device,
    queue: &wgpu::Queue,
    label: &'static str,
    data: &[T],
    usage: wgpu::BufferUsages,
) -> Result<PooledBuffer> {
    let contents = safe_bytes_of_slice(data);
    let effective = if contents.is_empty() {
        &[0u8; 16][..]
    } else {
        contents
    };
    let size = u64::try_from(effective.len()).map_err(|source| Error::Dataflow {
        message: format!(
            "buffer `{label}` has {} bytes that cannot fit u64: {source}. Fix: split the graph input before dispatch.",
            effective.len()
        ),
    })?;
    let buffer =
        BufferPool::global().acquire(device, label, size, usage | wgpu::BufferUsages::COPY_DST)?;
    queue.write_buffer(&buffer, 0, effective);
    Ok(buffer)
}

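/// Buffers bound to the BFS pipeline, listed in binding order: the node, offset, target, and
/// source buffers, followed by the findings output, the finding counter, the uniform
/// parameters, and the visited state.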
pub(crate) struct BfsBindGroupInputs<'a> {
    pub(crate) node_buffer: &'a wgpu::Buffer,
    pub(crate) offset_buffer: &'a wgpu::Buffer,
    pub(crate) target_buffer: &'a wgpu::Buffer,
    pub(crate) source_buffer: &'a wgpu::Buffer,
    pub(crate) findings_buffer: &'a wgpu::Buffer,
    pub(crate) finding_count_buffer: &'a wgpu::Buffer,
    pub(crate) params_buffer: &'a wgpu::Buffer,
    pub(crate) visited_buffer: &'a wgpu::Buffer,
}

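/// Builds the bind group for group 0 of the BFS pipeline from the buffers in
/// `BfsBindGroupInputs`, assigning bindings 0 through 7 in field order.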
pub(crate) fn create_bind_group(
    device: &wgpu::Device,
    pipeline: &wgpu::ComputePipeline,
    inputs: &BfsBindGroupInputs<'_>,
) -> wgpu::BindGroup {
    let layout = pipeline.get_bind_group_layout(0);
    device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("graph bfs bind group"),
        layout: &layout,
        entries: &[
            wgpu::BindGroupEntry {
                binding: 0,
                resource: inputs.node_buffer.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 1,
                resource: inputs.offset_buffer.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 2,
                resource: inputs.target_buffer.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 3,
                resource: inputs.source_buffer.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 4,
                resource: inputs.findings_buffer.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 5,
                resource: inputs.finding_count_buffer.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 6,
                resource: inputs.params_buffer.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 7,
                resource: inputs.visited_buffer.as_entire_binding(),
            },
        ],
    })
}

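/// Records the BFS compute pass, copies the findings and finding-count buffers into pooled
/// readback buffers, submits the work, and returns the readback buffers with the submission
/// index. External encoders are rejected before any commands are recorded, since readback
/// requires an owned submission.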
#[expect(
    clippy::too_many_arguments,
    reason = "GPU dispatch helpers pass distinct wgpu handles whose grouping would hide binding roles"
)]
pub(crate) fn dispatch_and_copy(
    device: &wgpu::Device,
    queue: &wgpu::Queue,
    pipeline: &wgpu::ComputePipeline,
    bind_group: &wgpu::BindGroup,
    findings_buffer: &wgpu::Buffer,
    finding_count_buffer: &wgpu::Buffer,
    findings_size: u64,
    workgroup_count: u32,
    encoder: Option<&mut wgpu::CommandEncoder>,
) -> Result<(PooledBuffer, PooledBuffer, wgpu::SubmissionIndex)> {
    if encoder.is_some() {
        // Reject the external-encoder path before acquiring pooled buffers or recording any
        // commands: readback needs the submission index that only the owned-encoder path yields,
        // and recording copies into soon-dropped pooled buffers would pollute the caller's encoder.
        return Err(Error::Dataflow {
            message: "dispatch_and_copy was called with an external encoder. Submit the caller-owned command encoder before readback. Fix: call with `None` for immediate submit behavior.".to_string(),
        });
    }
    let pool = BufferPool::global();
    let readback_buffer = pool.acquire(
        device,
        "graph findings readback",
        findings_size,
        wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
    )?;
    let readback_count_buffer = pool.acquire(
        device,
        "graph finding count readback",
        4,
        wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
    )?;
    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
        label: Some("graph bfs encoder"),
    });
    {
        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some("graph bfs pass"),
            timestamp_writes: None,
        });
        pass.set_pipeline(pipeline);
        pass.set_bind_group(0, bind_group, &[]);
        pass.dispatch_workgroups(workgroup_count, 1, 1);
    }
    encoder.copy_buffer_to_buffer(findings_buffer, 0, &readback_buffer, 0, findings_size);
    encoder.copy_buffer_to_buffer(finding_count_buffer, 0, &readback_count_buffer, 0, 4);
    let submission = queue.submit(Some(encoder.finish()));
    Ok((readback_buffer, readback_count_buffer, submission))
}

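/// Runs GPU BFS reachability using the cached device and queue; see `bfs_reachability` for the
/// underlying dispatch.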
pub fn bfs_reachability_cached(
    csr: &CsrGraph,
    sources: &[u32],
    max_depth: u32,
) -> Result<Vec<(u32, u32, u32)>> {
    let (device, queue) = crate::runtime::cached_device()?;
    bfs_reachability(device, queue, csr, sources, max_depth, None)
}