1use std::fmt;
9
10use vyre_driver::BackendError;
11use vyre_foundation::optimizer::eqsat_gpu::{
12 GpuEGraphDeviceImage, GpuEGraphDeviceImageError, GpuEGraphDeviceLayout, GpuEGraphDeviceSpan,
13 GpuEGraphSnapshot,
14};
15
16use crate::backend::{CudaBackend, CudaResidentBuffer};
17use crate::numeric::CUDA_NUMERIC;
18
19#[derive(Clone, Debug, Eq, PartialEq)]
21pub enum CudaEGraphDeviceUploadError {
22 Image(GpuEGraphDeviceImageError),
24 ByteSizeOverflow {
26 context: &'static str,
28 words: usize,
30 },
31}
32
33impl fmt::Display for CudaEGraphDeviceUploadError {
34 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
35 match self {
36 Self::Image(error) => error.fmt(f),
37 Self::ByteSizeOverflow { context, words } => write!(
38 f,
39 "CUDA e-graph upload {context} word value {words} overflows byte addressing. Fix: shard the e-graph upload before staging."
40 ),
41 }
42 }
43}
44
45impl std::error::Error for CudaEGraphDeviceUploadError {}
46
47impl From<GpuEGraphDeviceImageError> for CudaEGraphDeviceUploadError {
48 fn from(error: GpuEGraphDeviceImageError) -> Self {
49 Self::Image(error)
50 }
51}
52
/// A contiguous region of an uploaded e-graph image, measured in bytes.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct CudaEGraphDeviceByteSpan {
    offset: usize,
    byte_len: usize,
}

impl CudaEGraphDeviceByteSpan {
    /// Builds a span from a byte offset and a byte length.
    const fn new(start: usize, length: usize) -> Self {
        CudaEGraphDeviceByteSpan {
            offset: start,
            byte_len: length,
        }
    }

    /// Byte offset at which the span starts.
    #[must_use]
    pub const fn offset(&self) -> usize {
        let Self { offset, .. } = *self;
        offset
    }

    /// Number of bytes covered by the span.
    #[must_use]
    pub const fn byte_len(&self) -> usize {
        let Self { byte_len, .. } = *self;
        byte_len
    }

    /// Returns `true` when the span covers zero bytes.
    #[must_use]
    pub const fn is_empty(&self) -> bool {
        matches!(self.byte_len, 0)
    }
}
83
84#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
86pub struct CudaEGraphDeviceByteLayout {
87 row_count: usize,
88 child_count: usize,
89 eclass_group_count: usize,
90 row_eclass_ids: CudaEGraphDeviceByteSpan,
91 row_language_op_ids: CudaEGraphDeviceByteSpan,
92 row_children_offsets: CudaEGraphDeviceByteSpan,
93 row_children_lens: CudaEGraphDeviceByteSpan,
94 row_signatures: CudaEGraphDeviceByteSpan,
95 children: CudaEGraphDeviceByteSpan,
96 group_eclass_ids: CudaEGraphDeviceByteSpan,
97 group_offsets: CudaEGraphDeviceByteSpan,
98 group_rows: CudaEGraphDeviceByteSpan,
99}
100
101impl CudaEGraphDeviceByteLayout {
102 #[must_use]
104 pub const fn row_count(&self) -> usize {
105 self.row_count
106 }
107
108 #[must_use]
110 pub const fn child_count(&self) -> usize {
111 self.child_count
112 }
113
114 #[must_use]
116 pub const fn eclass_group_count(&self) -> usize {
117 self.eclass_group_count
118 }
119
120 #[must_use]
122 pub const fn row_eclass_ids(&self) -> CudaEGraphDeviceByteSpan {
123 self.row_eclass_ids
124 }
125
126 #[must_use]
128 pub const fn row_language_op_ids(&self) -> CudaEGraphDeviceByteSpan {
129 self.row_language_op_ids
130 }
131
132 #[must_use]
134 pub const fn row_children_offsets(&self) -> CudaEGraphDeviceByteSpan {
135 self.row_children_offsets
136 }
137
138 #[must_use]
140 pub const fn row_children_lens(&self) -> CudaEGraphDeviceByteSpan {
141 self.row_children_lens
142 }
143
144 #[must_use]
146 pub const fn row_signatures(&self) -> CudaEGraphDeviceByteSpan {
147 self.row_signatures
148 }
149
150 #[must_use]
152 pub const fn children(&self) -> CudaEGraphDeviceByteSpan {
153 self.children
154 }
155
156 #[must_use]
158 pub const fn group_eclass_ids(&self) -> CudaEGraphDeviceByteSpan {
159 self.group_eclass_ids
160 }
161
162 #[must_use]
164 pub const fn group_offsets(&self) -> CudaEGraphDeviceByteSpan {
165 self.group_offsets
166 }
167
168 #[must_use]
170 pub const fn group_rows(&self) -> CudaEGraphDeviceByteSpan {
171 self.group_rows
172 }
173}
174
175#[derive(Clone, Debug, Eq, PartialEq)]
177pub struct CudaEGraphDeviceUploadPlan {
178 image: GpuEGraphDeviceImage,
179 byte_layout: CudaEGraphDeviceByteLayout,
180 byte_len: usize,
181}
182
183impl CudaEGraphDeviceUploadPlan {
184 #[must_use]
186 pub fn words(&self) -> &[u32] {
187 self.image.words()
188 }
189
190 #[must_use]
192 pub const fn image(&self) -> &GpuEGraphDeviceImage {
193 &self.image
194 }
195
196 #[must_use]
198 pub const fn byte_layout(&self) -> CudaEGraphDeviceByteLayout {
199 self.byte_layout
200 }
201
202 #[must_use]
204 pub const fn byte_len(&self) -> usize {
205 self.byte_len
206 }
207}
208
209#[derive(Clone, Copy, Debug, Eq, PartialEq)]
215pub struct CudaEGraphDeviceBorrowedUploadPlan<'a> {
216 words: &'a [u32],
217 byte_layout: CudaEGraphDeviceByteLayout,
218 byte_len: usize,
219}
220
221impl<'a> CudaEGraphDeviceBorrowedUploadPlan<'a> {
222 #[must_use]
224 pub const fn words(&self) -> &'a [u32] {
225 self.words
226 }
227
228 #[must_use]
230 pub const fn byte_layout(&self) -> CudaEGraphDeviceByteLayout {
231 self.byte_layout
232 }
233
234 #[must_use]
236 pub const fn byte_len(&self) -> usize {
237 self.byte_len
238 }
239}
240
241#[derive(Clone, Copy, Debug, Eq, PartialEq)]
243pub struct CudaResidentEGraphDeviceImage {
244 handle: CudaResidentBuffer,
245 byte_layout: CudaEGraphDeviceByteLayout,
246 byte_len: usize,
247 word_count: usize,
248}
249
250impl CudaResidentEGraphDeviceImage {
251 #[must_use]
253 pub const fn handle(&self) -> CudaResidentBuffer {
254 self.handle
255 }
256
257 #[must_use]
259 pub const fn byte_layout(&self) -> CudaEGraphDeviceByteLayout {
260 self.byte_layout
261 }
262
263 #[must_use]
265 pub const fn byte_len(&self) -> usize {
266 self.byte_len
267 }
268
269 #[must_use]
271 pub const fn word_count(&self) -> usize {
272 self.word_count
273 }
274}
275
276#[derive(Clone, Copy, Debug, Eq, PartialEq)]
278pub struct CudaEGraphDeviceKernelView {
279 base_ptr: u64,
280 byte_len: usize,
281 row_count: usize,
282 child_count: usize,
283 eclass_group_count: usize,
284 row_eclass_ids_ptr: u64,
285 row_language_op_ids_ptr: u64,
286 row_children_offsets_ptr: u64,
287 row_children_lens_ptr: u64,
288 row_signatures_ptr: u64,
289 children_ptr: u64,
290 group_eclass_ids_ptr: u64,
291 group_offsets_ptr: u64,
292 group_rows_ptr: u64,
293}
294
295impl CudaEGraphDeviceKernelView {
296 pub fn from_checked_parts(
303 base_ptr: u64,
304 byte_len: usize,
305 layout: CudaEGraphDeviceByteLayout,
306 ) -> Result<Self, BackendError> {
307 Ok(Self {
308 base_ptr,
309 byte_len,
310 row_count: layout.row_count(),
311 child_count: layout.child_count(),
312 eclass_group_count: layout.eclass_group_count(),
313 row_eclass_ids_ptr: device_span_ptr(
314 base_ptr,
315 layout.row_eclass_ids(),
316 byte_len,
317 "row eclass ids",
318 )?,
319 row_language_op_ids_ptr: device_span_ptr(
320 base_ptr,
321 layout.row_language_op_ids(),
322 byte_len,
323 "row language op ids",
324 )?,
325 row_children_offsets_ptr: device_span_ptr(
326 base_ptr,
327 layout.row_children_offsets(),
328 byte_len,
329 "row child offsets",
330 )?,
331 row_children_lens_ptr: device_span_ptr(
332 base_ptr,
333 layout.row_children_lens(),
334 byte_len,
335 "row child lengths",
336 )?,
337 row_signatures_ptr: device_span_ptr(
338 base_ptr,
339 layout.row_signatures(),
340 byte_len,
341 "row signatures",
342 )?,
343 children_ptr: device_span_ptr(base_ptr, layout.children(), byte_len, "children")?,
344 group_eclass_ids_ptr: device_span_ptr(
345 base_ptr,
346 layout.group_eclass_ids(),
347 byte_len,
348 "group eclass ids",
349 )?,
350 group_offsets_ptr: device_span_ptr(
351 base_ptr,
352 layout.group_offsets(),
353 byte_len,
354 "group offsets",
355 )?,
356 group_rows_ptr: device_span_ptr(base_ptr, layout.group_rows(), byte_len, "group rows")?,
357 })
358 }
359
360 #[must_use]
362 pub const fn base_ptr(&self) -> u64 {
363 self.base_ptr
364 }
365
366 #[must_use]
368 pub const fn byte_len(&self) -> usize {
369 self.byte_len
370 }
371
372 #[must_use]
374 pub const fn row_count(&self) -> usize {
375 self.row_count
376 }
377
378 #[must_use]
380 pub const fn child_count(&self) -> usize {
381 self.child_count
382 }
383
384 #[must_use]
386 pub const fn eclass_group_count(&self) -> usize {
387 self.eclass_group_count
388 }
389
390 #[must_use]
392 pub const fn row_eclass_ids_ptr(&self) -> u64 {
393 self.row_eclass_ids_ptr
394 }
395
396 #[must_use]
398 pub const fn row_language_op_ids_ptr(&self) -> u64 {
399 self.row_language_op_ids_ptr
400 }
401
402 #[must_use]
404 pub const fn row_children_offsets_ptr(&self) -> u64 {
405 self.row_children_offsets_ptr
406 }
407
408 #[must_use]
410 pub const fn row_children_lens_ptr(&self) -> u64 {
411 self.row_children_lens_ptr
412 }
413
414 #[must_use]
416 pub const fn row_signatures_ptr(&self) -> u64 {
417 self.row_signatures_ptr
418 }
419
420 #[must_use]
422 pub const fn children_ptr(&self) -> u64 {
423 self.children_ptr
424 }
425
426 #[must_use]
428 pub const fn group_eclass_ids_ptr(&self) -> u64 {
429 self.group_eclass_ids_ptr
430 }
431
432 #[must_use]
434 pub const fn group_offsets_ptr(&self) -> u64 {
435 self.group_offsets_ptr
436 }
437
438 #[must_use]
440 pub const fn group_rows_ptr(&self) -> u64 {
441 self.group_rows_ptr
442 }
443}
444
445impl CudaBackend {
446 pub fn upload_egraph_device_image(
453 &self,
454 snapshot: &GpuEGraphSnapshot,
455 ) -> Result<CudaResidentEGraphDeviceImage, BackendError> {
456 let plan = plan_cuda_egraph_device_upload(snapshot)
457 .map_err(cuda_egraph_upload_plan_to_backend_error)?;
458 self.upload_egraph_device_image_plan(plan)
459 }
460
461 pub fn upload_egraph_device_image_plan(
468 &self,
469 plan: CudaEGraphDeviceUploadPlan,
470 ) -> Result<CudaResidentEGraphDeviceImage, BackendError> {
471 self.upload_egraph_device_image_words(plan.words(), plan.byte_layout(), plan.byte_len())
472 }
473
474 pub fn upload_egraph_device_image_borrowed_plan(
481 &self,
482 plan: CudaEGraphDeviceBorrowedUploadPlan<'_>,
483 ) -> Result<CudaResidentEGraphDeviceImage, BackendError> {
484 self.upload_egraph_device_image_words(plan.words(), plan.byte_layout(), plan.byte_len())
485 }
486
487 fn upload_egraph_device_image_words(
488 &self,
489 words: &[u32],
490 byte_layout: CudaEGraphDeviceByteLayout,
491 byte_len: usize,
492 ) -> Result<CudaResidentEGraphDeviceImage, BackendError> {
493 let word_count = words.len();
494 let handle = self.allocate_resident(byte_len)?;
495 if let Err(error) = upload_egraph_words_to_resident(self, handle, words) {
496 let _ = self.free_resident(handle);
497 return Err(error);
498 }
499 Ok(CudaResidentEGraphDeviceImage {
500 handle,
501 byte_layout,
502 byte_len,
503 word_count,
504 })
505 }
506
507 pub fn egraph_device_kernel_view(
514 &self,
515 image: CudaResidentEGraphDeviceImage,
516 ) -> Result<CudaEGraphDeviceKernelView, BackendError> {
517 let base_ptr = self.resident_device_ptr(image.handle())?;
518 CudaEGraphDeviceKernelView::from_checked_parts(
519 base_ptr,
520 image.byte_len(),
521 image.byte_layout(),
522 )
523 }
524}
525
526pub fn plan_cuda_egraph_device_upload(
534 snapshot: &GpuEGraphSnapshot,
535) -> Result<CudaEGraphDeviceUploadPlan, CudaEGraphDeviceUploadError> {
536 plan_cuda_egraph_device_upload_from_image(snapshot.try_pack_device_image()?)
537}
538
539pub fn plan_cuda_egraph_device_upload_from_image(
546 image: GpuEGraphDeviceImage,
547) -> Result<CudaEGraphDeviceUploadPlan, CudaEGraphDeviceUploadError> {
548 let borrowed = plan_cuda_egraph_device_upload_from_image_ref(&image)?;
549 let byte_layout = borrowed.byte_layout();
550 let byte_len = borrowed.byte_len();
551 Ok(CudaEGraphDeviceUploadPlan {
552 image,
553 byte_layout,
554 byte_len,
555 })
556}
557
558pub fn plan_cuda_egraph_device_upload_from_image_ref(
565 image: &GpuEGraphDeviceImage,
566) -> Result<CudaEGraphDeviceBorrowedUploadPlan<'_>, CudaEGraphDeviceUploadError> {
567 let layout = image.layout();
568 let byte_layout = cuda_byte_layout(layout)?;
569 let byte_len = checked_words_to_bytes(image.words().len(), "total upload length")?;
570 Ok(CudaEGraphDeviceBorrowedUploadPlan {
571 words: image.words(),
572 byte_layout,
573 byte_len,
574 })
575}
576
577fn cuda_byte_layout(
578 layout: GpuEGraphDeviceLayout,
579) -> Result<CudaEGraphDeviceByteLayout, CudaEGraphDeviceUploadError> {
580 Ok(CudaEGraphDeviceByteLayout {
581 row_count: layout.row_count(),
582 child_count: layout.child_count(),
583 eclass_group_count: layout.eclass_group_count(),
584 row_eclass_ids: byte_span(layout.row_eclass_ids(), "row eclass ids")?,
585 row_language_op_ids: byte_span(layout.row_language_op_ids(), "row language op ids")?,
586 row_children_offsets: byte_span(layout.row_children_offsets(), "row child offsets")?,
587 row_children_lens: byte_span(layout.row_children_lens(), "row child lengths")?,
588 row_signatures: byte_span(layout.row_signatures(), "row signatures")?,
589 children: byte_span(layout.children(), "children")?,
590 group_eclass_ids: byte_span(layout.group_eclass_ids(), "group eclass ids")?,
591 group_offsets: byte_span(layout.group_offsets(), "group offsets")?,
592 group_rows: byte_span(layout.group_rows(), "group rows")?,
593 })
594}
595
596fn byte_span(
597 span: GpuEGraphDeviceSpan,
598 context: &'static str,
599) -> Result<CudaEGraphDeviceByteSpan, CudaEGraphDeviceUploadError> {
600 Ok(CudaEGraphDeviceByteSpan::new(
601 checked_words_to_bytes(span.offset(), context)?,
602 checked_words_to_bytes(span.len(), context)?,
603 ))
604}
605
606fn checked_words_to_bytes(
607 words: usize,
608 context: &'static str,
609) -> Result<usize, CudaEGraphDeviceUploadError> {
610 words
611 .checked_mul(std::mem::size_of::<u32>())
612 .ok_or(CudaEGraphDeviceUploadError::ByteSizeOverflow { context, words })
613}
614
615fn upload_egraph_words_to_resident(
616 backend: &CudaBackend,
617 handle: CudaResidentBuffer,
618 words: &[u32],
619) -> Result<(), BackendError> {
620 #[cfg(target_endian = "little")]
621 {
622 backend.upload_resident(handle, bytemuck::cast_slice(words))
623 }
624 #[cfg(not(target_endian = "little"))]
625 {
626 let bytes = egraph_words_to_le_bytes(words)?;
627 backend.upload_resident(handle, &bytes)
628 }
629}
630
/// Encodes `words` as a little-endian byte vector.
///
/// The output buffer is reserved fallibly up front, so an out-of-memory
/// condition surfaces as an error instead of an abort.
///
/// # Errors
/// Returns `BackendError::InvalidProgram` when the byte length overflows or
/// the host allocation cannot be reserved.
#[cfg(not(target_endian = "little"))]
fn egraph_words_to_le_bytes(words: &[u32]) -> Result<Vec<u8>, BackendError> {
    let byte_len = match checked_words_to_bytes(words.len(), "resident egraph upload words") {
        Ok(len) => len,
        Err(error) => return Err(cuda_egraph_upload_plan_to_backend_error(error)),
    };

    let mut bytes = Vec::new();
    if let Err(error) = bytes.try_reserve_exact(byte_len) {
        return Err(BackendError::InvalidProgram {
            fix: format!(
                "Fix: CUDA e-graph resident upload could not reserve {byte_len} host byte(s): {error}. Shard the e-graph image before upload."
            ),
        });
    }

    bytes.extend(words.iter().flat_map(|word| word.to_le_bytes()));
    Ok(bytes)
}
648
649fn cuda_egraph_upload_plan_to_backend_error(error: CudaEGraphDeviceUploadError) -> BackendError {
650 BackendError::InvalidProgram {
651 fix: error.to_string(),
652 }
653}
654
655fn device_span_ptr(
656 base_ptr: u64,
657 span: CudaEGraphDeviceByteSpan,
658 image_byte_len: usize,
659 context: &'static str,
660) -> Result<u64, BackendError> {
661 let end = span
662 .offset()
663 .checked_add(span.byte_len())
664 .ok_or_else(|| BackendError::InvalidProgram {
665 fix: format!(
666 "Fix: CUDA e-graph kernel view span `{context}` overflows usize. Rebuild or shard the image before launch."
667 ),
668 })?;
669 if end > image_byte_len {
670 return Err(BackendError::InvalidProgram {
671 fix: format!(
672 "Fix: CUDA e-graph kernel view span `{context}` points to bytes [{}..{end}) but resident image has {image_byte_len} bytes.",
673 span.offset()
674 ),
675 });
676 }
677 base_ptr
678 .checked_add(CUDA_NUMERIC.usize_to_u64(
679 span.offset(),
680 "CUDA e-graph kernel view byte offset",
681 )?)
682 .ok_or_else(|| BackendError::InvalidProgram {
683 fix: format!(
684 "Fix: CUDA e-graph kernel view pointer arithmetic overflowed for span `{context}` at byte offset {}.",
685 span.offset()
686 ),
687 })
688}
689