1mod persistent_handles;
4mod readback_dispatch;
5mod types;
6
7use super::builder::{build_program_jit_slots, build_program_sharded_slots_shared};
8use super::handlers::OpcodeHandler;
9use super::io;
10use super::planner::MegakernelLaunchGeometry;
11use super::protocol;
12use super::protocol_api::{validate_control_bytes, validate_debug_log_bytes};
13use super::recovery::{
14 backend_error_indicates_device_loss, recover_compiled_pipeline, MegakernelRecoveryDecision,
15 MegakernelRecoveryPolicy,
16};
17use super::staging_reserve::reserve_vec_capacity;
18use crate::PipelineError;
19use arc_swap::ArcSwap;
20use std::sync::Arc;
21use std::time::Instant;
22use vyre_driver::backend::{
23 CompiledPipeline, DispatchConfig, OutputBuffers, Resource, VyreBackend,
24};
25use vyre_foundation::ir::Program;
26
27pub use types::{
28 MegakernelBatchDispatchOutput, MegakernelDispatchOutput, MegakernelDispatchStats,
29 MegakernelResidentBatchScratch, MegakernelResidentHandles,
30};
31
/// A compiled megakernel bound to one backend, together with its launch geometry
/// and the host-side state needed to dispatch it.
///
/// The `bootstrap*` constructors validate that `slot_count` is a non-zero multiple
/// of `workgroup_size_x`, so `slot_count / workgroup_size_x` is always well defined.
pub struct Megakernel {
    // Backend the program was compiled for. It is also used by the grid-sync split
    // fallback and when recompiling after device loss.
    backend: Arc<dyn VyreBackend>,
    // Currently installed compiled pipeline. It sits behind `ArcSwap` so
    // `recover_after_device_loss` can replace it through `&self`.
    pipeline: ArcSwap<PipelineSlot>,
    // Id of the pipeline compiled at bootstrap. Recovery does NOT update this value.
    pipeline_id: String,
    // The megakernel program. Kept for recompilation and for the grid-sync
    // fallback, which dispatches the program directly.
    program: Arc<Program>,
    // Whether `program` contains a grid sync (per `vyre_driver::grid_sync::contains_grid_sync`).
    has_grid_sync: bool,
    // Pre-encoded empty I/O queue, used by the dispatch entry points that take no queue.
    empty_io_queue_bytes: Arc<[u8]>,
    // Number of slots in the ring.
    slot_count: u32,
    // Lanes per workgroup along X.
    workgroup_size_x: u32,
    // Governs the single retry after device loss. Initialized to
    // `MegakernelRecoveryPolicy::default()` at bootstrap.
    recovery_policy: MegakernelRecoveryPolicy,
}
48
/// The value stored in `Megakernel::pipeline`'s `ArcSwap`; it wraps the compiled
/// pipeline trait object.
///
/// NOTE(review): presumably this wrapper exists because `ArcSwap` needs a sized
/// payload and `dyn CompiledPipeline` is unsized — confirm.
struct PipelineSlot {
    // Compiled pipeline that dispatches go through.
    inner: Arc<dyn CompiledPipeline>,
}
52
53impl Megakernel {
54 pub fn bootstrap(backend: Arc<dyn VyreBackend>) -> Result<Self, PipelineError> {
60 Self::bootstrap_sharded(backend, 256, 256, Vec::new())
61 }
62
63 pub fn bootstrap_with_opcodes(
69 backend: Arc<dyn VyreBackend>,
70 opcodes: Vec<OpcodeHandler>,
71 ) -> Result<Self, PipelineError> {
72 Self::bootstrap_sharded(backend, 256, 256, opcodes)
73 }
74
75 pub fn worker_groups_for_geometry(
82 slot_count: u32,
83 workgroup_size_x: u32,
84 ) -> Result<u32, PipelineError> {
85 validate_bootstrap_geometry(slot_count, workgroup_size_x)?;
86 Ok(slot_count / workgroup_size_x)
87 }
88
89 pub fn bootstrap_sharded(
96 backend: Arc<dyn VyreBackend>,
97 slot_count: u32,
98 workgroup_size_x: u32,
99 opcodes: Vec<OpcodeHandler>,
100 ) -> Result<Self, PipelineError> {
101 validate_bootstrap_geometry(slot_count, workgroup_size_x)?;
102 let program = build_program_sharded_slots_shared(workgroup_size_x, slot_count, &opcodes);
103 Self::compile_bootstrap_shared(backend, slot_count, workgroup_size_x, program)
104 }
105
106 pub fn bootstrap_jit(
112 backend: Arc<dyn VyreBackend>,
113 slot_count: u32,
114 workgroup_size_x: u32,
115 payload_processor: &[vyre_foundation::ir::Node],
116 ) -> Result<Self, PipelineError> {
117 validate_bootstrap_geometry(slot_count, workgroup_size_x)?;
118 let program = build_program_jit_slots(workgroup_size_x, slot_count, payload_processor);
119 Self::compile_bootstrap(backend, slot_count, workgroup_size_x, program)
120 }
121
122 fn compile_bootstrap(
123 backend: Arc<dyn VyreBackend>,
124 slot_count: u32,
125 workgroup_size_x: u32,
126 program: Program,
127 ) -> Result<Self, PipelineError> {
128 Self::compile_bootstrap_shared(backend, slot_count, workgroup_size_x, Arc::new(program))
129 }
130
131 fn compile_bootstrap_shared(
132 backend: Arc<dyn VyreBackend>,
133 slot_count: u32,
134 workgroup_size_x: u32,
135 program: Arc<Program>,
136 ) -> Result<Self, PipelineError> {
137 validate_bootstrap_geometry(slot_count, workgroup_size_x)?;
138 let config = DispatchConfig::default();
139 let pipeline = vyre_driver::pipeline::compile_shared(
140 Arc::clone(&backend),
141 Arc::clone(&program),
142 &config,
143 )?;
144 let pipeline_id = pipeline.id().to_string();
145 let has_grid_sync = vyre_driver::grid_sync::contains_grid_sync(&program);
146 let empty_io_queue_bytes =
147 Arc::<[u8]>::from(io::try_encode_empty_io_queue(io::IO_SLOT_COUNT)?.into_boxed_slice());
148 Ok(Self {
149 backend,
150 pipeline: ArcSwap::from(Arc::new(PipelineSlot { inner: pipeline })),
151 pipeline_id,
152 program,
153 has_grid_sync,
154 empty_io_queue_bytes,
155 slot_count,
156 workgroup_size_x,
157 recovery_policy: MegakernelRecoveryPolicy::default(),
158 })
159 }
160
161 pub fn dispatch(
168 &self,
169 control_bytes: Vec<u8>,
170 ring_bytes: Vec<u8>,
171 debug_log_bytes: Vec<u8>,
172 ) -> Result<Vec<Vec<u8>>, PipelineError> {
173 self.dispatch_borrowed(&control_bytes, &ring_bytes, &debug_log_bytes)
174 }
175
176 pub fn dispatch_borrowed(
183 &self,
184 control_bytes: &[u8],
185 ring_bytes: &[u8],
186 debug_log_bytes: &[u8],
187 ) -> Result<Vec<Vec<u8>>, PipelineError> {
188 Ok(self
189 .dispatch_borrowed_observed(control_bytes, ring_bytes, debug_log_bytes)?
190 .buffers)
191 }
192
193 pub fn dispatch_observed(
199 &self,
200 control_bytes: Vec<u8>,
201 ring_bytes: Vec<u8>,
202 debug_log_bytes: Vec<u8>,
203 ) -> Result<MegakernelDispatchOutput, PipelineError> {
204 self.dispatch_with_io_queue_borrowed_observed(
205 &control_bytes,
206 &ring_bytes,
207 &debug_log_bytes,
208 &self.empty_io_queue_bytes,
209 )
210 }
211
212 pub fn dispatch_borrowed_observed(
219 &self,
220 control_bytes: &[u8],
221 ring_bytes: &[u8],
222 debug_log_bytes: &[u8],
223 ) -> Result<MegakernelDispatchOutput, PipelineError> {
224 self.dispatch_with_io_queue_borrowed_observed(
225 control_bytes,
226 ring_bytes,
227 debug_log_bytes,
228 &self.empty_io_queue_bytes,
229 )
230 }
231
232 pub fn dispatch_with_io_queue(
239 &self,
240 control_bytes: Vec<u8>,
241 ring_bytes: Vec<u8>,
242 debug_log_bytes: Vec<u8>,
243 io_queue_bytes: Vec<u8>,
244 ) -> Result<Vec<Vec<u8>>, PipelineError> {
245 self.dispatch_with_io_queue_borrowed(
246 &control_bytes,
247 &ring_bytes,
248 &debug_log_bytes,
249 &io_queue_bytes,
250 )
251 }
252
253 pub fn dispatch_with_io_queue_borrowed(
259 &self,
260 control_bytes: &[u8],
261 ring_bytes: &[u8],
262 debug_log_bytes: &[u8],
263 io_queue_bytes: &[u8],
264 ) -> Result<Vec<Vec<u8>>, PipelineError> {
265 Ok(self
266 .dispatch_with_io_queue_borrowed_observed(
267 control_bytes,
268 ring_bytes,
269 debug_log_bytes,
270 io_queue_bytes,
271 )?
272 .buffers)
273 }
274
275 pub fn dispatch_with_io_queue_observed(
281 &self,
282 control_bytes: Vec<u8>,
283 ring_bytes: Vec<u8>,
284 debug_log_bytes: Vec<u8>,
285 io_queue_bytes: Vec<u8>,
286 ) -> Result<MegakernelDispatchOutput, PipelineError> {
287 self.dispatch_with_io_queue_borrowed_observed(
288 &control_bytes,
289 &ring_bytes,
290 &debug_log_bytes,
291 &io_queue_bytes,
292 )
293 }
294
295 pub fn dispatch_with_io_queue_borrowed_observed(
302 &self,
303 control_bytes: &[u8],
304 ring_bytes: &[u8],
305 debug_log_bytes: &[u8],
306 io_queue_bytes: &[u8],
307 ) -> Result<MegakernelDispatchOutput, PipelineError> {
308 let mut buffers = Vec::new();
309 reserve_output_shell(
310 &mut buffers,
311 MegakernelResidentHandles::ABI_RESOURCE_COUNT,
312 "borrowed megakernel output shell",
313 )?;
314 let stats = self.dispatch_with_io_queue_borrowed_into(
315 control_bytes,
316 ring_bytes,
317 debug_log_bytes,
318 io_queue_bytes,
319 &mut buffers,
320 )?;
321 Ok(MegakernelDispatchOutput { buffers, stats })
322 }
323
324 pub fn dispatch_with_io_queue_borrowed_into(
331 &self,
332 control_bytes: &[u8],
333 ring_bytes: &[u8],
334 debug_log_bytes: &[u8],
335 io_queue_bytes: &[u8],
336 outputs: &mut OutputBuffers,
337 ) -> Result<MegakernelDispatchStats, PipelineError> {
338 validate_control_bytes(control_bytes)?;
339 validate_debug_log_bytes(debug_log_bytes)?;
340 io::validate_io_queue_bytes(io_queue_bytes)?;
341 self.validate_ring_bytes(ring_bytes)?;
342
343 let input_bytes = total_len([control_bytes, ring_bytes, debug_log_bytes, io_queue_bytes])?;
344 let inputs = [control_bytes, ring_bytes, debug_log_bytes, io_queue_bytes];
345 let config = self.launch_geometry().dispatch_config(None);
346 let started = Instant::now();
347 let mut recovered = false;
348 match self.dispatch_once_into(&inputs, &config, outputs) {
349 Ok(()) => {}
350 Err(error) if self.recovery_policy.allows_retry(&error) => {
351 self.recover_after_device_loss()?;
352 recovered = true;
353 self.dispatch_once_into(&inputs, &config, outputs)?
354 }
355 Err(error) => return Err(error.into()),
356 }
357 let latency_ns = nanos_u64(started.elapsed().as_nanos())?;
358 let output_bytes = output_bytes(outputs)?;
359 let readback_bytes = output_bytes;
360 let bytes_moved = checked_add_u64(input_bytes, readback_bytes, "megakernel bytes moved")?;
361 let device_allocation_bytes = checked_add_u64(
362 input_bytes,
363 output_bytes,
364 "megakernel host-visible device allocation bytes",
365 )?;
366 let output_buffers = count_u32(outputs.len(), "megakernel output buffer count")?;
367 let device_allocation_events =
368 checked_add_u32(4, output_buffers, "megakernel allocation event count")?;
369 Ok(MegakernelDispatchStats {
370 input_bytes,
371 output_bytes,
372 readback_bytes,
373 bytes_moved,
374 device_allocation_bytes,
375 device_allocation_events,
376 latency_ns,
377 output_buffers,
378 resident_resource_rows: 0,
379 resident_resource_handles: 0,
380 kernel_launches: if recovered { 2 } else { 1 },
381 sync_points: 1,
382 recovered_after_device_loss: recovered,
383 })
384 }
385
386 pub fn recover_after_device_loss(&self) -> Result<MegakernelRecoveryDecision, PipelineError> {
395 let config = self.launch_geometry().dispatch_config(None);
396 let rebuilt = recover_compiled_pipeline(&self.backend, Arc::clone(&self.program), &config)?;
397 self.pipeline
398 .store(Arc::new(PipelineSlot { inner: rebuilt }));
399 Ok(MegakernelRecoveryDecision::RecompiledPipeline)
400 }
401
402 #[must_use]
404 pub fn pipeline_id(&self) -> &str {
405 &self.pipeline_id
406 }
407
408 #[must_use]
410 pub const fn slot_count(&self) -> u32 {
411 self.slot_count
412 }
413
414 #[must_use]
416 pub const fn workgroup_size_x(&self) -> u32 {
417 self.workgroup_size_x
418 }
419
420 #[must_use]
422 pub fn worker_groups(&self) -> u32 {
423 self.slot_count / self.workgroup_size_x
424 }
425
426 pub(super) fn validate_ring_bytes(&self, ring_bytes: &[u8]) -> Result<(), PipelineError> {
427 let expected_ring_bytes = protocol::ring_byte_len(self.slot_count).ok_or_else(|| {
428 PipelineError::Backend(
429 "megakernel ring byte length overflowed usize. Fix: split the ring into smaller dispatch shards."
430 .to_string(),
431 )
432 })?;
433 if ring_bytes.len() != expected_ring_bytes {
434 return Err(PipelineError::Backend(format!(
435 "megakernel ring buffer has {} bytes, expected {expected_ring_bytes} for {} slots. Fix: build ring bytes with Megakernel::encode_empty_ring(slot_count) for this handle.",
436 ring_bytes.len(),
437 self.slot_count
438 )));
439 }
440 Ok(())
441 }
442
443 pub(super) fn launch_geometry(&self) -> MegakernelLaunchGeometry {
444 MegakernelLaunchGeometry {
445 workgroup_size_x: self.workgroup_size_x,
446 slot_count: self.slot_count,
447 dispatch_grid: [self.slot_count / self.workgroup_size_x, 1, 1],
448 }
449 }
450
451 fn dispatch_once_into(
452 &self,
453 inputs: &[&[u8]; 4],
454 config: &DispatchConfig,
455 outputs: &mut OutputBuffers,
456 ) -> Result<(), vyre_driver::BackendError> {
457 if self.has_grid_sync && !self.backend.supports_grid_sync() {
458 return vyre_driver::grid_sync::dispatch_with_grid_sync_split_into(
459 self.backend.as_ref(),
460 &self.program,
461 inputs,
462 config,
463 outputs,
464 );
465 }
466 let pipeline = self.pipeline.load();
467 pipeline
468 .inner
469 .dispatch_borrowed_into(inputs, config, outputs)
470 }
471
472 fn dispatch_persistent_handles_once_into(
473 &self,
474 inputs: &[Resource; 4],
475 config: &DispatchConfig,
476 outputs: &mut OutputBuffers,
477 ) -> Result<(), vyre_driver::BackendError> {
478 let pipeline = self.pipeline.load();
479 pipeline
480 .inner
481 .dispatch_persistent_handles_into(inputs, config, outputs)
482 }
483
484 fn dispatch_persistent_handle_rows_once_into(
485 &self,
486 rows: &[[Resource; 4]],
487 config: &DispatchConfig,
488 outputs: &mut Vec<OutputBuffers>,
489 ) -> Result<(), vyre_driver::BackendError> {
490 let pipeline = self.pipeline.load();
491 pipeline
492 .inner
493 .dispatch_persistent_handle_rows_into(rows, config, outputs)
494 }
495}
496
497
498impl MegakernelRecoveryPolicy {
499 fn allows_retry(self, error: &vyre_driver::BackendError) -> bool {
500 self.retry_device_loss_once && backend_error_indicates_device_loss(error)
501 }
502}
503
504fn validate_bootstrap_geometry(
505 slot_count: u32,
506 workgroup_size_x: u32,
507) -> Result<(), PipelineError> {
508 if slot_count == 0 || workgroup_size_x == 0 || slot_count % workgroup_size_x != 0 {
509 return Err(PipelineError::QueueFull {
510 queue: "submission",
511 fix: "slot_count must be a non-zero multiple of workgroup_size_x",
512 });
513 }
514 Ok(())
515}
516
517pub(super) fn total_len<const N: usize>(buffers: [&[u8]; N]) -> Result<u64, PipelineError> {
518 let mut total = 0u64;
519 for buffer in buffers {
520 total = checked_add_u64(
521 total,
522 usize_to_u64(buffer.len(), "megakernel input buffer length")?,
523 "megakernel input byte total",
524 )?;
525 }
526 Ok(total)
527}
528
529pub(super) fn output_bytes(outputs: &[Vec<u8>]) -> Result<u64, PipelineError> {
530 let mut total = 0u64;
531 for output in outputs {
532 total = checked_add_u64(
533 total,
534 usize_to_u64(output.len(), "megakernel output buffer length")?,
535 "megakernel output byte total",
536 )?;
537 }
538 Ok(total)
539}
540
541pub(super) fn nested_output_bytes(outputs: &[Vec<Vec<u8>>]) -> Result<u64, PipelineError> {
542 let mut total = 0u64;
543 for row in outputs {
544 total = checked_add_u64(
545 total,
546 output_bytes(row)?,
547 "megakernel nested output byte total",
548 )?;
549 }
550 Ok(total)
551}
552
553pub(super) fn output_count_u32(outputs: &[Vec<u8>]) -> Result<u32, PipelineError> {
554 count_u32(outputs.len(), "megakernel output buffer count")
555}
556
557pub(super) fn nested_output_count_u32(outputs: &[Vec<Vec<u8>>]) -> Result<u32, PipelineError> {
558 let mut total = 0usize;
559 for row in outputs {
560 total = total.checked_add(row.len()).ok_or_else(|| {
561 PipelineError::Backend(
562 "megakernel nested output buffer count overflowed usize. Fix: split resident rows before dispatch.".to_string(),
563 )
564 })?;
565 }
566 count_u32(total, "megakernel nested output buffer count")
567}
568
569pub(super) fn resident_row_count_u32(rows: usize) -> Result<u32, PipelineError> {
570 count_u32(rows, "megakernel resident resource row count")
571}
572
573pub(super) fn resident_handle_count_u32(rows: usize) -> Result<u32, PipelineError> {
574 let handles = rows
575 .checked_mul(MegakernelResidentHandles::ABI_RESOURCE_COUNT)
576 .ok_or_else(|| {
577 PipelineError::Backend(
578 "megakernel resident resource handle count overflowed usize. Fix: split resident rows before dispatch.".to_string(),
579 )
580 })?;
581 count_u32(handles, "megakernel resident resource handle count")
582}
583
584pub(super) fn reserve_output_shell<T>(
585 out: &mut Vec<T>,
586 capacity: usize,
587 label: &'static str,
588) -> Result<(), PipelineError> {
589 reserve_vec_capacity(out, capacity, label)
590}
591
592pub(super) fn nanos_u64(nanos: u128) -> Result<u64, PipelineError> {
593 u64::try_from(nanos).map_err(|source| {
594 PipelineError::Backend(format!(
595 "megakernel latency cannot fit u64 nanoseconds: {source}. Fix: timeout or split the dispatch before telemetry overflows."
596 ))
597 })
598}
599
600fn usize_to_u64(value: usize, label: &str) -> Result<u64, PipelineError> {
601 u64::try_from(value).map_err(|source| {
602 PipelineError::Backend(format!(
603 "{label} cannot fit u64: {source}. Fix: split the megakernel dispatch before telemetry/accounting."
604 ))
605 })
606}
607
608fn count_u32(value: usize, label: &str) -> Result<u32, PipelineError> {
609 u32::try_from(value).map_err(|source| {
610 PipelineError::Backend(format!(
611 "{label} cannot fit u32: {source}. Fix: split the megakernel dispatch before telemetry/accounting."
612 ))
613 })
614}
615
616fn checked_add_u64(left: u64, right: u64, label: &str) -> Result<u64, PipelineError> {
617 left.checked_add(right).ok_or_else(|| {
618 PipelineError::Backend(format!(
619 "{label} overflowed u64. Fix: split the megakernel dispatch before telemetry/accounting."
620 ))
621 })
622}
623
624fn checked_add_u32(left: u32, right: u32, label: &str) -> Result<u32, PipelineError> {
625 left.checked_add(right).ok_or_else(|| {
626 PipelineError::Backend(format!(
627 "{label} overflowed u32. Fix: split the megakernel dispatch before telemetry/accounting."
628 ))
629 })
630}
631
632#[cfg(test)]
633mod tests;
634