Skip to main content

vyre_driver_cuda/
egraph_device_image.rs

1//! CUDA upload planning for GPU e-graph device images.
2//!
3//! The e-graph substrate stays in `vyre-foundation`; this module only
4//! translates its validated u32 device image into CUDA byte spans. That keeps
5//! equality-saturation semantics out of the backend while giving the CUDA
6//! path a single-copy upload contract.
7
8use std::fmt;
9
10use vyre_driver::BackendError;
11use vyre_foundation::optimizer::eqsat_gpu::{
12    GpuEGraphDeviceImage, GpuEGraphDeviceImageError, GpuEGraphDeviceLayout, GpuEGraphDeviceSpan,
13    GpuEGraphSnapshot,
14};
15
16use crate::backend::{CudaBackend, CudaResidentBuffer};
17use crate::numeric::CUDA_NUMERIC;
18
19/// Error returned when a CUDA e-graph upload plan cannot be built.
20#[derive(Clone, Debug, Eq, PartialEq)]
21pub enum CudaEGraphDeviceUploadError {
22    /// Foundation image packing rejected the snapshot.
23    Image(GpuEGraphDeviceImageError),
24    /// A word span could not be represented as byte offsets.
25    ByteSizeOverflow {
26        /// Segment being translated.
27        context: &'static str,
28        /// Word count or word offset that overflowed when scaled by four.
29        words: usize,
30    },
31}
32
33impl fmt::Display for CudaEGraphDeviceUploadError {
34    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
35        match self {
36            Self::Image(error) => error.fmt(f),
37            Self::ByteSizeOverflow { context, words } => write!(
38                f,
39                "CUDA e-graph upload {context} word value {words} overflows byte addressing. Fix: shard the e-graph upload before staging."
40            ),
41        }
42    }
43}
44
45impl std::error::Error for CudaEGraphDeviceUploadError {}
46
47impl From<GpuEGraphDeviceImageError> for CudaEGraphDeviceUploadError {
48    fn from(error: GpuEGraphDeviceImageError) -> Self {
49        Self::Image(error)
50    }
51}
52
/// A contiguous byte range within the single CUDA e-graph upload slab.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct CudaEGraphDeviceByteSpan {
    offset: usize,
    byte_len: usize,
}

impl CudaEGraphDeviceByteSpan {
    /// Assemble a span from an already byte-scaled offset and length.
    const fn new(offset: usize, byte_len: usize) -> Self {
        Self { byte_len, offset }
    }

    /// Where the span starts, in bytes from the beginning of the upload slab.
    #[must_use]
    pub const fn offset(&self) -> usize {
        self.offset
    }

    /// How many bytes the span covers.
    #[must_use]
    pub const fn byte_len(&self) -> usize {
        self.byte_len
    }

    /// `true` iff the span covers zero bytes.
    #[must_use]
    pub const fn is_empty(&self) -> bool {
        matches!(self.byte_len, 0)
    }
}
83
84/// CUDA byte layout for a packed e-graph device image.
85#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
86pub struct CudaEGraphDeviceByteLayout {
87    row_count: usize,
88    child_count: usize,
89    eclass_group_count: usize,
90    row_eclass_ids: CudaEGraphDeviceByteSpan,
91    row_language_op_ids: CudaEGraphDeviceByteSpan,
92    row_children_offsets: CudaEGraphDeviceByteSpan,
93    row_children_lens: CudaEGraphDeviceByteSpan,
94    row_signatures: CudaEGraphDeviceByteSpan,
95    children: CudaEGraphDeviceByteSpan,
96    group_eclass_ids: CudaEGraphDeviceByteSpan,
97    group_offsets: CudaEGraphDeviceByteSpan,
98    group_rows: CudaEGraphDeviceByteSpan,
99}
100
101impl CudaEGraphDeviceByteLayout {
102    /// Number of snapshot rows in the upload image.
103    #[must_use]
104    pub const fn row_count(&self) -> usize {
105        self.row_count
106    }
107
108    /// Number of child references in the upload image.
109    #[must_use]
110    pub const fn child_count(&self) -> usize {
111        self.child_count
112    }
113
114    /// Number of e-class row groups in the upload image.
115    #[must_use]
116    pub const fn eclass_group_count(&self) -> usize {
117        self.eclass_group_count
118    }
119
120    /// Byte span containing one e-class id per row.
121    #[must_use]
122    pub const fn row_eclass_ids(&self) -> CudaEGraphDeviceByteSpan {
123        self.row_eclass_ids
124    }
125
126    /// Byte span containing one language op id per row.
127    #[must_use]
128    pub const fn row_language_op_ids(&self) -> CudaEGraphDeviceByteSpan {
129        self.row_language_op_ids
130    }
131
132    /// Byte span containing one child-column offset per row.
133    #[must_use]
134    pub const fn row_children_offsets(&self) -> CudaEGraphDeviceByteSpan {
135        self.row_children_offsets
136    }
137
138    /// Byte span containing one child count per row.
139    #[must_use]
140    pub const fn row_children_lens(&self) -> CudaEGraphDeviceByteSpan {
141        self.row_children_lens
142    }
143
144    /// Byte span containing one structural signature per row.
145    #[must_use]
146    pub const fn row_signatures(&self) -> CudaEGraphDeviceByteSpan {
147        self.row_signatures
148    }
149
150    /// Byte span containing the flat child e-class column.
151    #[must_use]
152    pub const fn children(&self) -> CudaEGraphDeviceByteSpan {
153        self.children
154    }
155
156    /// Byte span containing sorted grouped e-class ids.
157    #[must_use]
158    pub const fn group_eclass_ids(&self) -> CudaEGraphDeviceByteSpan {
159        self.group_eclass_ids
160    }
161
162    /// Byte span containing prefix offsets into [`Self::group_rows`].
163    #[must_use]
164    pub const fn group_offsets(&self) -> CudaEGraphDeviceByteSpan {
165        self.group_offsets
166    }
167
168    /// Byte span containing row indices grouped by e-class.
169    #[must_use]
170    pub const fn group_rows(&self) -> CudaEGraphDeviceByteSpan {
171        self.group_rows
172    }
173}
174
175/// CUDA upload plan for a validated foundation e-graph device image.
176#[derive(Clone, Debug, Eq, PartialEq)]
177pub struct CudaEGraphDeviceUploadPlan {
178    image: GpuEGraphDeviceImage,
179    byte_layout: CudaEGraphDeviceByteLayout,
180    byte_len: usize,
181}
182
183impl CudaEGraphDeviceUploadPlan {
184    /// Packed u32 words to copy into CUDA-pinned staging memory.
185    #[must_use]
186    pub fn words(&self) -> &[u32] {
187        self.image.words()
188    }
189
190    /// Foundation-owned logical image.
191    #[must_use]
192    pub const fn image(&self) -> &GpuEGraphDeviceImage {
193        &self.image
194    }
195
196    /// CUDA byte layout for kernel parameters.
197    #[must_use]
198    pub const fn byte_layout(&self) -> CudaEGraphDeviceByteLayout {
199        self.byte_layout
200    }
201
202    /// Total number of bytes required for the CUDA upload slab.
203    #[must_use]
204    pub const fn byte_len(&self) -> usize {
205        self.byte_len
206    }
207}
208
209/// Borrowed CUDA upload plan for an already-packed foundation e-graph image.
210///
211/// This is the release hot path when the caller still needs to inspect the
212/// packed image for launch planning after upload. It avoids cloning the full
213/// packed slab just to satisfy the owned upload-plan API.
214#[derive(Clone, Copy, Debug, Eq, PartialEq)]
215pub struct CudaEGraphDeviceBorrowedUploadPlan<'a> {
216    words: &'a [u32],
217    byte_layout: CudaEGraphDeviceByteLayout,
218    byte_len: usize,
219}
220
221impl<'a> CudaEGraphDeviceBorrowedUploadPlan<'a> {
222    /// Packed u32 words to copy into CUDA-pinned staging memory.
223    #[must_use]
224    pub const fn words(&self) -> &'a [u32] {
225        self.words
226    }
227
228    /// CUDA byte layout for kernel parameters.
229    #[must_use]
230    pub const fn byte_layout(&self) -> CudaEGraphDeviceByteLayout {
231        self.byte_layout
232    }
233
234    /// Total number of bytes required for the CUDA upload slab.
235    #[must_use]
236    pub const fn byte_len(&self) -> usize {
237        self.byte_len
238    }
239}
240
241/// CUDA-resident e-graph device image plus the byte layout kernels need.
242#[derive(Clone, Copy, Debug, Eq, PartialEq)]
243pub struct CudaResidentEGraphDeviceImage {
244    handle: CudaResidentBuffer,
245    byte_layout: CudaEGraphDeviceByteLayout,
246    byte_len: usize,
247    word_count: usize,
248}
249
250impl CudaResidentEGraphDeviceImage {
251    /// Resident CUDA buffer containing the packed u32 e-graph image.
252    #[must_use]
253    pub const fn handle(&self) -> CudaResidentBuffer {
254        self.handle
255    }
256
257    /// CUDA byte layout for kernel parameters.
258    #[must_use]
259    pub const fn byte_layout(&self) -> CudaEGraphDeviceByteLayout {
260        self.byte_layout
261    }
262
263    /// Total bytes uploaded to the resident image buffer.
264    #[must_use]
265    pub const fn byte_len(&self) -> usize {
266        self.byte_len
267    }
268
269    /// Total u32 words uploaded to the resident image buffer.
270    #[must_use]
271    pub const fn word_count(&self) -> usize {
272        self.word_count
273    }
274}
275
276/// Checked kernel-facing pointer view of a CUDA-resident e-graph image.
277#[derive(Clone, Copy, Debug, Eq, PartialEq)]
278pub struct CudaEGraphDeviceKernelView {
279    base_ptr: u64,
280    byte_len: usize,
281    row_count: usize,
282    child_count: usize,
283    eclass_group_count: usize,
284    row_eclass_ids_ptr: u64,
285    row_language_op_ids_ptr: u64,
286    row_children_offsets_ptr: u64,
287    row_children_lens_ptr: u64,
288    row_signatures_ptr: u64,
289    children_ptr: u64,
290    group_eclass_ids_ptr: u64,
291    group_offsets_ptr: u64,
292    group_rows_ptr: u64,
293}
294
295impl CudaEGraphDeviceKernelView {
296    /// Build a kernel view from a base pointer, byte length, and byte layout.
297    ///
298    /// # Errors
299    ///
300    /// Returns [`BackendError`] if any layout span points outside the image or
301    /// if pointer arithmetic overflows.
302    pub fn from_checked_parts(
303        base_ptr: u64,
304        byte_len: usize,
305        layout: CudaEGraphDeviceByteLayout,
306    ) -> Result<Self, BackendError> {
307        Ok(Self {
308            base_ptr,
309            byte_len,
310            row_count: layout.row_count(),
311            child_count: layout.child_count(),
312            eclass_group_count: layout.eclass_group_count(),
313            row_eclass_ids_ptr: device_span_ptr(
314                base_ptr,
315                layout.row_eclass_ids(),
316                byte_len,
317                "row eclass ids",
318            )?,
319            row_language_op_ids_ptr: device_span_ptr(
320                base_ptr,
321                layout.row_language_op_ids(),
322                byte_len,
323                "row language op ids",
324            )?,
325            row_children_offsets_ptr: device_span_ptr(
326                base_ptr,
327                layout.row_children_offsets(),
328                byte_len,
329                "row child offsets",
330            )?,
331            row_children_lens_ptr: device_span_ptr(
332                base_ptr,
333                layout.row_children_lens(),
334                byte_len,
335                "row child lengths",
336            )?,
337            row_signatures_ptr: device_span_ptr(
338                base_ptr,
339                layout.row_signatures(),
340                byte_len,
341                "row signatures",
342            )?,
343            children_ptr: device_span_ptr(base_ptr, layout.children(), byte_len, "children")?,
344            group_eclass_ids_ptr: device_span_ptr(
345                base_ptr,
346                layout.group_eclass_ids(),
347                byte_len,
348                "group eclass ids",
349            )?,
350            group_offsets_ptr: device_span_ptr(
351                base_ptr,
352                layout.group_offsets(),
353                byte_len,
354                "group offsets",
355            )?,
356            group_rows_ptr: device_span_ptr(base_ptr, layout.group_rows(), byte_len, "group rows")?,
357        })
358    }
359
360    /// Base device pointer of the packed e-graph image.
361    #[must_use]
362    pub const fn base_ptr(&self) -> u64 {
363        self.base_ptr
364    }
365
366    /// Total byte length of the resident image.
367    #[must_use]
368    pub const fn byte_len(&self) -> usize {
369        self.byte_len
370    }
371
372    /// Number of e-graph rows.
373    #[must_use]
374    pub const fn row_count(&self) -> usize {
375        self.row_count
376    }
377
378    /// Number of child e-class references.
379    #[must_use]
380    pub const fn child_count(&self) -> usize {
381        self.child_count
382    }
383
384    /// Number of grouped e-class row spans.
385    #[must_use]
386    pub const fn eclass_group_count(&self) -> usize {
387        self.eclass_group_count
388    }
389
390    /// Device pointer to the row e-class id column.
391    #[must_use]
392    pub const fn row_eclass_ids_ptr(&self) -> u64 {
393        self.row_eclass_ids_ptr
394    }
395
396    /// Device pointer to the row language-op id column.
397    #[must_use]
398    pub const fn row_language_op_ids_ptr(&self) -> u64 {
399        self.row_language_op_ids_ptr
400    }
401
402    /// Device pointer to the row child-offset column.
403    #[must_use]
404    pub const fn row_children_offsets_ptr(&self) -> u64 {
405        self.row_children_offsets_ptr
406    }
407
408    /// Device pointer to the row child-length column.
409    #[must_use]
410    pub const fn row_children_lens_ptr(&self) -> u64 {
411        self.row_children_lens_ptr
412    }
413
414    /// Device pointer to the row structural-signature column.
415    #[must_use]
416    pub const fn row_signatures_ptr(&self) -> u64 {
417        self.row_signatures_ptr
418    }
419
420    /// Device pointer to the flat child e-class column.
421    #[must_use]
422    pub const fn children_ptr(&self) -> u64 {
423        self.children_ptr
424    }
425
426    /// Device pointer to sorted grouped e-class ids.
427    #[must_use]
428    pub const fn group_eclass_ids_ptr(&self) -> u64 {
429        self.group_eclass_ids_ptr
430    }
431
432    /// Device pointer to group prefix offsets.
433    #[must_use]
434    pub const fn group_offsets_ptr(&self) -> u64 {
435        self.group_offsets_ptr
436    }
437
438    /// Device pointer to row indices grouped by e-class.
439    #[must_use]
440    pub const fn group_rows_ptr(&self) -> u64 {
441        self.group_rows_ptr
442    }
443}
444
445impl CudaBackend {
446    /// Pack and upload an e-graph snapshot into one CUDA-resident buffer.
447    ///
448    /// # Errors
449    ///
450    /// Returns [`BackendError`] if the snapshot is malformed, the image cannot
451    /// be represented as CUDA byte spans, or resident allocation/upload fails.
452    pub fn upload_egraph_device_image(
453        &self,
454        snapshot: &GpuEGraphSnapshot,
455    ) -> Result<CudaResidentEGraphDeviceImage, BackendError> {
456        let plan = plan_cuda_egraph_device_upload(snapshot)
457            .map_err(cuda_egraph_upload_plan_to_backend_error)?;
458        self.upload_egraph_device_image_plan(plan)
459    }
460
461    /// Upload an already-planned e-graph image into one CUDA-resident buffer.
462    ///
463    /// # Errors
464    ///
465    /// Returns [`BackendError`] if host-byte staging, resident allocation, or
466    /// resident upload fails.
467    pub fn upload_egraph_device_image_plan(
468        &self,
469        plan: CudaEGraphDeviceUploadPlan,
470    ) -> Result<CudaResidentEGraphDeviceImage, BackendError> {
471        self.upload_egraph_device_image_words(plan.words(), plan.byte_layout(), plan.byte_len())
472    }
473
474    /// Upload a borrowed e-graph image plan into one CUDA-resident buffer.
475    ///
476    /// # Errors
477    ///
478    /// Returns [`BackendError`] if host-byte staging, resident allocation, or
479    /// resident upload fails.
480    pub fn upload_egraph_device_image_borrowed_plan(
481        &self,
482        plan: CudaEGraphDeviceBorrowedUploadPlan<'_>,
483    ) -> Result<CudaResidentEGraphDeviceImage, BackendError> {
484        self.upload_egraph_device_image_words(plan.words(), plan.byte_layout(), plan.byte_len())
485    }
486
487    fn upload_egraph_device_image_words(
488        &self,
489        words: &[u32],
490        byte_layout: CudaEGraphDeviceByteLayout,
491        byte_len: usize,
492    ) -> Result<CudaResidentEGraphDeviceImage, BackendError> {
493        let word_count = words.len();
494        let handle = self.allocate_resident(byte_len)?;
495        if let Err(error) = upload_egraph_words_to_resident(self, handle, words) {
496            let _ = self.free_resident(handle);
497            return Err(error);
498        }
499        Ok(CudaResidentEGraphDeviceImage {
500            handle,
501            byte_layout,
502            byte_len,
503            word_count,
504        })
505    }
506
507    /// Resolve a resident e-graph image into checked kernel pointer metadata.
508    ///
509    /// # Errors
510    ///
511    /// Returns [`BackendError`] if the resident handle is not owned by this
512    /// backend or if any byte span would point outside the resident image.
513    pub fn egraph_device_kernel_view(
514        &self,
515        image: CudaResidentEGraphDeviceImage,
516    ) -> Result<CudaEGraphDeviceKernelView, BackendError> {
517        let base_ptr = self.resident_device_ptr(image.handle())?;
518        CudaEGraphDeviceKernelView::from_checked_parts(
519            base_ptr,
520            image.byte_len(),
521            image.byte_layout(),
522        )
523    }
524}
525
526/// Build a CUDA upload plan directly from a foundation e-graph snapshot.
527///
528/// # Errors
529///
530/// Returns [`CudaEGraphDeviceUploadError`] if the snapshot cannot be packed or
531/// if the packed word spans overflow host byte addressing.
532
533pub fn plan_cuda_egraph_device_upload(
534    snapshot: &GpuEGraphSnapshot,
535) -> Result<CudaEGraphDeviceUploadPlan, CudaEGraphDeviceUploadError> {
536    plan_cuda_egraph_device_upload_from_image(snapshot.try_pack_device_image()?)
537}
538
539/// Build a CUDA upload plan from an already-packed foundation image.
540///
541/// # Errors
542///
543/// Returns [`CudaEGraphDeviceUploadError`] if a packed word span overflows host
544/// byte addressing.
545pub fn plan_cuda_egraph_device_upload_from_image(
546    image: GpuEGraphDeviceImage,
547) -> Result<CudaEGraphDeviceUploadPlan, CudaEGraphDeviceUploadError> {
548    let borrowed = plan_cuda_egraph_device_upload_from_image_ref(&image)?;
549    let byte_layout = borrowed.byte_layout();
550    let byte_len = borrowed.byte_len();
551    Ok(CudaEGraphDeviceUploadPlan {
552        image,
553        byte_layout,
554        byte_len,
555    })
556}
557
558/// Build a borrowed CUDA upload plan from an already-packed foundation image.
559///
560/// # Errors
561///
562/// Returns [`CudaEGraphDeviceUploadError`] if a packed word span overflows host
563/// byte addressing.
564pub fn plan_cuda_egraph_device_upload_from_image_ref(
565    image: &GpuEGraphDeviceImage,
566) -> Result<CudaEGraphDeviceBorrowedUploadPlan<'_>, CudaEGraphDeviceUploadError> {
567    let layout = image.layout();
568    let byte_layout = cuda_byte_layout(layout)?;
569    let byte_len = checked_words_to_bytes(image.words().len(), "total upload length")?;
570    Ok(CudaEGraphDeviceBorrowedUploadPlan {
571        words: image.words(),
572        byte_layout,
573        byte_len,
574    })
575}
576
577fn cuda_byte_layout(
578    layout: GpuEGraphDeviceLayout,
579) -> Result<CudaEGraphDeviceByteLayout, CudaEGraphDeviceUploadError> {
580    Ok(CudaEGraphDeviceByteLayout {
581        row_count: layout.row_count(),
582        child_count: layout.child_count(),
583        eclass_group_count: layout.eclass_group_count(),
584        row_eclass_ids: byte_span(layout.row_eclass_ids(), "row eclass ids")?,
585        row_language_op_ids: byte_span(layout.row_language_op_ids(), "row language op ids")?,
586        row_children_offsets: byte_span(layout.row_children_offsets(), "row child offsets")?,
587        row_children_lens: byte_span(layout.row_children_lens(), "row child lengths")?,
588        row_signatures: byte_span(layout.row_signatures(), "row signatures")?,
589        children: byte_span(layout.children(), "children")?,
590        group_eclass_ids: byte_span(layout.group_eclass_ids(), "group eclass ids")?,
591        group_offsets: byte_span(layout.group_offsets(), "group offsets")?,
592        group_rows: byte_span(layout.group_rows(), "group rows")?,
593    })
594}
595
596fn byte_span(
597    span: GpuEGraphDeviceSpan,
598    context: &'static str,
599) -> Result<CudaEGraphDeviceByteSpan, CudaEGraphDeviceUploadError> {
600    Ok(CudaEGraphDeviceByteSpan::new(
601        checked_words_to_bytes(span.offset(), context)?,
602        checked_words_to_bytes(span.len(), context)?,
603    ))
604}
605
606fn checked_words_to_bytes(
607    words: usize,
608    context: &'static str,
609) -> Result<usize, CudaEGraphDeviceUploadError> {
610    words
611        .checked_mul(std::mem::size_of::<u32>())
612        .ok_or(CudaEGraphDeviceUploadError::ByteSizeOverflow { context, words })
613}
614
615fn upload_egraph_words_to_resident(
616    backend: &CudaBackend,
617    handle: CudaResidentBuffer,
618    words: &[u32],
619) -> Result<(), BackendError> {
620    #[cfg(target_endian = "little")]
621    {
622        backend.upload_resident(handle, bytemuck::cast_slice(words))
623    }
624    #[cfg(not(target_endian = "little"))]
625    {
626        let bytes = egraph_words_to_le_bytes(words)?;
627        backend.upload_resident(handle, &bytes)
628    }
629}
630
/// Serialize packed words into a freshly allocated little-endian byte buffer.
///
/// Only compiled on hosts that are not little-endian.
///
/// # Errors
///
/// Returns [`BackendError`] if the byte length overflows `usize` or if the
/// host buffer cannot be reserved.
#[cfg(not(target_endian = "little"))]
fn egraph_words_to_le_bytes(words: &[u32]) -> Result<Vec<u8>, BackendError> {
    let byte_len = match checked_words_to_bytes(words.len(), "resident egraph upload words") {
        Ok(len) => len,
        Err(error) => return Err(cuda_egraph_upload_plan_to_backend_error(error)),
    };

    // Reserve fallibly so an oversized image becomes an error, not an abort.
    let mut bytes: Vec<u8> = Vec::new();
    if let Err(error) = bytes.try_reserve_exact(byte_len) {
        return Err(BackendError::InvalidProgram {
            fix: format!(
                "Fix: CUDA e-graph resident upload could not reserve {byte_len} host byte(s): {error}. Shard the e-graph image before upload."
            ),
        });
    }

    bytes.extend(words.iter().flat_map(|word| word.to_le_bytes()));
    Ok(bytes)
}
648
649fn cuda_egraph_upload_plan_to_backend_error(error: CudaEGraphDeviceUploadError) -> BackendError {
650    BackendError::InvalidProgram {
651        fix: error.to_string(),
652    }
653}
654
655fn device_span_ptr(
656    base_ptr: u64,
657    span: CudaEGraphDeviceByteSpan,
658    image_byte_len: usize,
659    context: &'static str,
660) -> Result<u64, BackendError> {
661    let end = span
662        .offset()
663        .checked_add(span.byte_len())
664        .ok_or_else(|| BackendError::InvalidProgram {
665            fix: format!(
666                "Fix: CUDA e-graph kernel view span `{context}` overflows usize. Rebuild or shard the image before launch."
667            ),
668        })?;
669    if end > image_byte_len {
670        return Err(BackendError::InvalidProgram {
671            fix: format!(
672                "Fix: CUDA e-graph kernel view span `{context}` points to bytes [{}..{end}) but resident image has {image_byte_len} bytes.",
673                span.offset()
674            ),
675        });
676    }
677    base_ptr
678        .checked_add(CUDA_NUMERIC.usize_to_u64(
679            span.offset(),
680            "CUDA e-graph kernel view byte offset",
681        )?)
682        .ok_or_else(|| BackendError::InvalidProgram {
683            fix: format!(
684                "Fix: CUDA e-graph kernel view pointer arithmetic overflowed for span `{context}` at byte offset {}.",
685                span.offset()
686            ),
687        })
688}
689