Skip to main content

vyre_wgpu/engine/dataflow/
mod.rs

1//! Generic GPU graph dataflow engines.
2//!
3//! Host-side workflow dispatcher that owns GPU resources (buffers, pipelines,
4//! readback), runs Programs from `vyre::ops::graph`, and returns typed host
5//! results.
6
7mod bfs;
8
9use self::bfs::bfs_reachability::bfs_reachability;
10use crate::bytemuck_safe::{safe_bytes_of_slice, safe_cast_slice};
11use crate::runtime::cache::{BufferPool, PooledBuffer};
12use bytemuck::{Pod, Zeroable};
13use vyre::error::{Error, Result};
14use vyre::ops::graph::bfs::Bfs;
15use vyre::ops::graph::csr::CsrGraph;
16
/// Invocations per compute workgroup used when sizing BFS dispatches
/// (see `dispatch_and_copy`'s `workgroup_count` parameter).
pub const WORKGROUP_SIZE: u32 = 64;

/// Maximum BFS findings allocated by one dispatch.
pub const MAX_FINDINGS: usize = 1_000_000;
22
/// One BFS reachability row read back from the GPU findings buffer.
///
/// Four consecutive `u32` words; `#[repr(C)]` + `Pod` so the readback bytes
/// can be reinterpreted directly via `safe_cast_slice`. The layout must match
/// the finding record written by the lowered BFS WGSL — TODO confirm against
/// the lowerer in `vyre::lower::wgsl`.
#[repr(C)]
#[derive(Copy, Clone, Debug, Pod, Zeroable)]
pub(crate) struct GpuFinding {
    // Source node the traversal started from.
    start_node: u32,
    // Node reached from `start_node`.
    sink_node: u32,
    // Traversal depth at which `sink_node` was reached.
    depth: u32,
    // Index of the start node within the dispatch's source list.
    source_idx: u32,
}
31
/// Parameter block handed to the BFS compute shader.
///
/// Eight `u32` words (32 bytes); the trailing `_pad*` fields keep the struct
/// a multiple of 16 bytes for GPU buffer alignment — presumably required by
/// the uniform/storage layout rules of the lowered WGSL (verify against the
/// shader's struct definition).
#[repr(C)]
#[derive(Copy, Clone, Debug, Pod, Zeroable)]
pub(crate) struct GpuParams {
    // Number of BFS source nodes in this dispatch.
    num_sources: u32,
    // Node count of the CSR graph.
    num_nodes: u32,
    // Row capacity of the findings buffer.
    max_findings: u32,
    // Depth limit for the traversal.
    max_depth: u32,
    // `u32` words per source in the visited bitset buffer.
    words_per_source: u32,
    // Explicit padding — never read by host or shader.
    _pad0: u32,
    _pad1: u32,
    _pad2: u32,
}
44
45pub(crate) fn validate_inputs(csr: &CsrGraph, sources: &[u32]) -> Result<()> {
46    csr.validate()?;
47    for &source in sources {
48        let source_index = usize::try_from(source).map_err(|err| Error::Dataflow {
49            message: format!(
50                "source node {source} cannot fit usize: {err}. Fix: reject this source on this platform."
51            ),
52        })?;
53        if source_index >= csr.node_count() {
54            return Err(Error::Dataflow {
55                message: format!(
56                    "source node {source} is outside node_count {}. Fix: remove invalid sources before GPU dispatch.",
57                    csr.node_count()
58                ),
59            });
60        }
61    }
62    Ok(())
63}
64
65pub(crate) fn checked_u32(value: usize, label: &str) -> Result<u32> {
66    u32::try_from(value).map_err(|source| Error::Dataflow {
67        message: format!("{label} value {value} exceeds u32::MAX: {source}. Fix: split the graph input before GPU dispatch."),
68    })
69}
70
71/// `read_finding_count` function.
72pub fn read_finding_count(
73    device: &wgpu::Device,
74    readback_count_buffer: &wgpu::Buffer,
75    max_findings: u32,
76    submission: wgpu::SubmissionIndex,
77) -> Result<u32> {
78    let slice = readback_count_buffer.slice(0..4);
79    let (tx, rx) = std::sync::mpsc::channel();
80    slice.map_async(wgpu::MapMode::Read, move |result| {
81        let _ = tx.send(result);
82    });
83    match device.poll(wgpu::Maintain::wait_for(submission)) {
84        wgpu::MaintainResult::Ok | wgpu::MaintainResult::SubmissionQueueEmpty => {}
85    }
86    rx.recv()
87        .map_err(|error| Error::Dataflow {
88            message: format!("GPU finding count map channel closed: {error}. Fix: keep the readback receiver alive until map_async completes."),
89        })?
90        .map_err(|error| Error::Dataflow {
91            message: format!("GPU finding count map failed: {error:?}. Fix: check for device loss, adapter timeout, or invalid readback buffer usage."),
92        })?;
93    let mapped = slice.get_mapped_range();
94    let count = u32::from_ne_bytes(mapped[..4].try_into().map_err(|source| {
95        Error::Dataflow {
96            message: format!("GPU finding count readback was not four bytes: {source}. Fix: inspect the count readback buffer size."),
97        }
98    })?)
99    .min(max_findings);
100    drop(mapped);
101    readback_count_buffer.unmap();
102    Ok(count)
103}
104
105/// `read_finding_rows` function.
106pub fn read_finding_rows(
107    device: &wgpu::Device,
108    readback_buffer: &wgpu::Buffer,
109    count: u32,
110    submission: wgpu::SubmissionIndex,
111) -> Result<Vec<(u32, u32, u32)>> {
112    let finding_byte_len = u64::from(count)
113        * u64::try_from(std::mem::size_of::<GpuFinding>()).map_err(|source| Error::Dataflow {
114            message: format!(
115                "GpuFinding size cannot fit u64: {source}. Fix: run on a supported target."
116            ),
117        })?;
118    let slice = readback_buffer.slice(0..finding_byte_len);
119    let (tx, rx) = std::sync::mpsc::channel();
120    slice.map_async(wgpu::MapMode::Read, move |result| {
121        let _ = tx.send(result);
122    });
123    match device.poll(wgpu::Maintain::wait_for(submission)) {
124        wgpu::MaintainResult::Ok | wgpu::MaintainResult::SubmissionQueueEmpty => {}
125    }
126    rx.recv()
127        .map_err(|error| Error::Dataflow {
128            message: format!("GPU findings map channel closed: {error}. Fix: keep the readback receiver alive until map_async completes."),
129        })?
130        .map_err(|error| Error::Dataflow {
131            message: format!("GPU findings map failed: {error:?}. Fix: check for device loss, adapter timeout, or invalid readback buffer usage."),
132        })?;
133    let mapped = slice.get_mapped_range();
134    let findings: &[GpuFinding] = safe_cast_slice(&mapped).map_err(|error| Error::Dataflow {
135        message: format!(
136            "safe cast failed: {error}. Fix: ensure the readback buffer size matches GpuFinding layout."
137        ),
138    })?;
139    let results = findings
140        .iter()
141        .take(usize::try_from(count).map_err(|source| Error::Dataflow {
142            message: format!("finding count {count} cannot fit usize: {source}. Fix: reject this readback on this platform."),
143        })?)
144        .map(|finding| (finding.start_node, finding.sink_node, finding.depth))
145        .collect();
146    drop(mapped);
147    readback_buffer.unmap();
148    Ok(results)
149}
150
151pub(crate) fn read_findings(
152    device: &wgpu::Device,
153    readback_buffer: &wgpu::Buffer,
154    readback_count_buffer: &wgpu::Buffer,
155    max_findings: u32,
156    submission: wgpu::SubmissionIndex,
157) -> Result<Vec<(u32, u32, u32)>> {
158    let count = read_finding_count(
159        device,
160        readback_count_buffer,
161        max_findings,
162        submission.clone(),
163    )?;
164    if count == 0 {
165        return Ok(Vec::new());
166    }
167    read_finding_rows(device, readback_buffer, count, submission)
168}
169
170pub(crate) fn create_pipeline(
171    device: &wgpu::Device,
172    queue_slots: u32,
173) -> Result<wgpu::ComputePipeline> {
174    vyre::ops::registry::gate::verify_certificate(Bfs::SPEC.id()).map_err(|source| {
175        Error::Dataflow {
176            message: source.to_string(),
177        }
178    })?;
179    let source = vyre::lower::wgsl::lower(&Bfs::program_with_queue_size(queue_slots)).map_err(
180        |source| Error::Dataflow {
181            message: format!("failed to lower graph BFS IR to WGSL: {source}. Fix: validate the canonical BFS Program and lowerer before dispatch."),
182        },
183    )?;
184    crate::runtime::compile_compute_pipeline(device, "graph bfs pipeline", &source, "main")
185        .map_err(|source| Error::Dataflow {
186            message: format!(
187                "failed to compile graph BFS pipeline: {source}. Fix: repair the BFS WGSL or GPU runtime configuration."
188            ),
189        })
190}
191
192pub(crate) fn create_buffer<T: Pod>(
193    device: &wgpu::Device,
194    queue: &wgpu::Queue,
195    label: &'static str,
196    data: &[T],
197    usage: wgpu::BufferUsages,
198) -> Result<PooledBuffer> {
199    let contents = safe_bytes_of_slice(data);
200    let effective = if contents.is_empty() {
201        &[0u8; 16][..]
202    } else {
203        contents
204    };
205    let size = u64::try_from(effective.len()).map_err(|source| Error::Dataflow {
206        message: format!(
207            "buffer `{label}` has {} bytes that cannot fit u64: {source}. Fix: split the graph input before dispatch.",
208            effective.len()
209        ),
210    })?;
211    let buffer =
212        BufferPool::global().acquire(device, label, size, usage | wgpu::BufferUsages::COPY_DST)?;
213    queue.write_buffer(&buffer, 0, effective);
214    Ok(buffer)
215}
216
/// Buffers bound by the graph BFS compute pipeline.
///
/// Field order matches binding slots 0..=7 in `create_bind_group`; keep the
/// two in sync when adding or reordering buffers.
pub(crate) struct BfsBindGroupInputs<'a> {
    /// CSR node metadata buffer (binding 0).
    pub(crate) node_buffer: &'a wgpu::Buffer,
    /// CSR row-offset buffer (binding 1).
    pub(crate) offset_buffer: &'a wgpu::Buffer,
    /// CSR adjacency target buffer (binding 2).
    pub(crate) target_buffer: &'a wgpu::Buffer,
    /// Starting source-node buffer (binding 3).
    pub(crate) source_buffer: &'a wgpu::Buffer,
    /// Storage buffer receiving discovered `(start, sink, depth)` rows (binding 4).
    pub(crate) findings_buffer: &'a wgpu::Buffer,
    /// Single-word storage buffer receiving the findings count (binding 5).
    pub(crate) finding_count_buffer: &'a wgpu::Buffer,
    /// Dispatch parameter block (binding 6).
    pub(crate) params_buffer: &'a wgpu::Buffer,
    /// Per-source visited bitset buffer (binding 7).
    pub(crate) visited_buffer: &'a wgpu::Buffer,
}
236
237pub(crate) fn create_bind_group(
238    device: &wgpu::Device,
239    pipeline: &wgpu::ComputePipeline,
240    inputs: &BfsBindGroupInputs<'_>,
241) -> wgpu::BindGroup {
242    let layout = pipeline.get_bind_group_layout(0);
243    device.create_bind_group(&wgpu::BindGroupDescriptor {
244        label: Some("graph bfs bind group"),
245        layout: &layout,
246        entries: &[
247            wgpu::BindGroupEntry {
248                binding: 0,
249                resource: inputs.node_buffer.as_entire_binding(),
250            },
251            wgpu::BindGroupEntry {
252                binding: 1,
253                resource: inputs.offset_buffer.as_entire_binding(),
254            },
255            wgpu::BindGroupEntry {
256                binding: 2,
257                resource: inputs.target_buffer.as_entire_binding(),
258            },
259            wgpu::BindGroupEntry {
260                binding: 3,
261                resource: inputs.source_buffer.as_entire_binding(),
262            },
263            wgpu::BindGroupEntry {
264                binding: 4,
265                resource: inputs.findings_buffer.as_entire_binding(),
266            },
267            wgpu::BindGroupEntry {
268                binding: 5,
269                resource: inputs.finding_count_buffer.as_entire_binding(),
270            },
271            wgpu::BindGroupEntry {
272                binding: 6,
273                resource: inputs.params_buffer.as_entire_binding(),
274            },
275            wgpu::BindGroupEntry {
276                binding: 7,
277                resource: inputs.visited_buffer.as_entire_binding(),
278            },
279        ],
280    })
281}
282
283#[expect(
284    clippy::too_many_arguments,
285    reason = "GPU dispatch helpers pass distinct wgpu handles whose grouping would hide binding roles"
286)]
287pub(crate) fn dispatch_and_copy(
288    device: &wgpu::Device,
289    queue: &wgpu::Queue,
290    pipeline: &wgpu::ComputePipeline,
291    bind_group: &wgpu::BindGroup,
292    findings_buffer: &wgpu::Buffer,
293    finding_count_buffer: &wgpu::Buffer,
294    findings_size: u64,
295    workgroup_count: u32,
296    encoder: Option<&mut wgpu::CommandEncoder>,
297) -> Result<(PooledBuffer, PooledBuffer, wgpu::SubmissionIndex)> {
298    let pool = BufferPool::global();
299    let readback_buffer = pool.acquire(
300        device,
301        "graph findings readback",
302        findings_size,
303        wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
304    )?;
305    let readback_count_buffer = pool.acquire(
306        device,
307        "graph finding count readback",
308        4,
309        wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
310    )?;
311    let mut owned_encoder: Option<wgpu::CommandEncoder> = encoder.is_none().then(|| {
312        device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
313            label: Some("graph bfs encoder"),
314        })
315    });
316    let encoder = if let Some(encoder) = encoder {
317        encoder
318    } else {
319        owned_encoder
320            .as_mut()
321            .expect("owned encoder must be present when encoder is omitted")
322    };
323    {
324        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
325            label: Some("graph bfs pass"),
326            timestamp_writes: None,
327        });
328        pass.set_pipeline(pipeline);
329        pass.set_bind_group(0, bind_group, &[]);
330        pass.dispatch_workgroups(workgroup_count, 1, 1);
331    }
332    encoder.copy_buffer_to_buffer(findings_buffer, 0, &readback_buffer, 0, findings_size);
333    encoder.copy_buffer_to_buffer(finding_count_buffer, 0, &readback_count_buffer, 0, 4);
334    let submission = if let Some(encoder) = owned_encoder {
335        queue.submit(Some(encoder.finish()))
336    } else {
337        return Err(Error::Dataflow {
338            message: "dispatch_and_copy was called with an external encoder. Submit the caller-owned command encoder before readback. Fix: call with `None` for immediate submit behavior.".to_string(),
339        });
340    };
341    Ok((readback_buffer, readback_count_buffer, submission))
342}
343
/// Run multi-source BFS reachability using vyre's cached runtime device.
///
/// Thin wrapper: obtains the process-wide cached `(device, queue)` pair and
/// forwards to `bfs_reachability` with no external encoder (`None`).
///
/// # Errors
///
/// Returns `Error::Gpu` if the cached GPU device cannot be initialized.
pub fn bfs_reachability_cached(
    csr: &CsrGraph,
    sources: &[u32],
    max_depth: u32,
) -> Result<Vec<(u32, u32, u32)>> {
    let (device, queue) = crate::runtime::cached_device()?;
    bfs_reachability(device, queue, csr, sources, max_depth, None)
}
357
358// The main bfs_reachability function lives in its own file due to size.