1mod persistent_handles;
4mod readback_dispatch;
5mod types;
6
7use super::builder::{build_program_jit_slots, build_program_sharded_slots_shared};
8use super::handlers::OpcodeHandler;
9use super::io;
10use super::planner::MegakernelLaunchGeometry;
11use super::protocol;
12use super::protocol_api::{validate_control_bytes, validate_debug_log_bytes};
13use super::recovery::{
14 backend_error_indicates_device_loss, recover_compiled_pipeline, MegakernelRecoveryDecision,
15 MegakernelRecoveryPolicy,
16};
17use super::staging_reserve::reserve_vec_capacity;
18use crate::PipelineError;
19use arc_swap::ArcSwap;
20use std::sync::Arc;
21use std::time::Instant;
22use vyre_driver::backend::{
23 CompiledPipeline, DispatchConfig, OutputBuffers, Resource, VyreBackend,
24};
25use vyre_foundation::ir::Program;
26
27pub use types::{
28 MegakernelBatchDispatchOutput, MegakernelDispatchOutput, MegakernelDispatchStats,
29 MegakernelResidentBatchScratch, MegakernelResidentHandles,
30};
31
/// A compiled megakernel bound to a single backend.
///
/// Holds the IR program, the currently active compiled pipeline, and the
/// slot/workgroup geometry validated at bootstrap, so the same handle can be
/// dispatched repeatedly and recompiled after a device loss.
pub struct Megakernel {
    /// Backend used to compile the program and to run grid-sync split dispatches.
    backend: Arc<dyn VyreBackend>,
    /// Currently active compiled pipeline; replaced wholesale by
    /// `recover_after_device_loss`.
    pipeline: ArcSwap<PipelineSlot>,
    /// Id of the pipeline compiled at bootstrap.
    /// NOTE(review): not refreshed when `recover_after_device_loss` swaps in a
    /// recompiled pipeline — confirm the id is stable across recompiles.
    pipeline_id: String,
    /// IR program retained for recompilation and for the grid-sync split dispatch path.
    program: Arc<Program>,
    /// Whether `program` contains a grid-wide sync, as reported by
    /// `vyre_driver::grid_sync::contains_grid_sync` at bootstrap.
    has_grid_sync: bool,
    /// Pre-encoded empty IO queue (`io::IO_SLOT_COUNT` slots), used by dispatch
    /// entry points that take no caller-supplied IO queue.
    empty_io_queue_bytes: Arc<[u8]>,
    /// Number of slots; the expected ring byte length is derived from it.
    /// Validated at bootstrap to be a non-zero multiple of `workgroup_size_x`.
    slot_count: u32,
    /// Workgroup width in the x dimension (non-zero, validated at bootstrap).
    workgroup_size_x: u32,
    /// Decides whether a device-loss dispatch error triggers one recompile-and-retry.
    recovery_policy: MegakernelRecoveryPolicy,
}
48
/// Holder for the currently active compiled pipeline.
///
/// `Megakernel` keeps this behind an `ArcSwap`, so `recover_after_device_loss`
/// can publish a recompiled pipeline through a shared `&self`.
struct PipelineSlot {
    /// Backend-compiled pipeline used by every dispatch path except the
    /// grid-sync split fallback.
    inner: Arc<dyn CompiledPipeline>,
}
52
53impl Megakernel {
54 pub fn bootstrap(backend: Arc<dyn VyreBackend>) -> Result<Self, PipelineError> {
60 Self::bootstrap_sharded(backend, 256, 256, Vec::new())
61 }
62
63 pub fn bootstrap_with_opcodes(
69 backend: Arc<dyn VyreBackend>,
70 opcodes: Vec<OpcodeHandler>,
71 ) -> Result<Self, PipelineError> {
72 Self::bootstrap_sharded(backend, 256, 256, opcodes)
73 }
74
75 pub fn worker_groups_for_geometry(
82 slot_count: u32,
83 workgroup_size_x: u32,
84 ) -> Result<u32, PipelineError> {
85 validate_bootstrap_geometry(slot_count, workgroup_size_x)?;
86 Ok(slot_count / workgroup_size_x)
87 }
88
89 pub fn bootstrap_sharded(
96 backend: Arc<dyn VyreBackend>,
97 slot_count: u32,
98 workgroup_size_x: u32,
99 opcodes: Vec<OpcodeHandler>,
100 ) -> Result<Self, PipelineError> {
101 validate_bootstrap_geometry(slot_count, workgroup_size_x)?;
102 let program = build_program_sharded_slots_shared(workgroup_size_x, slot_count, &opcodes);
103 Self::compile_bootstrap_shared(backend, slot_count, workgroup_size_x, program)
104 }
105
106 pub fn bootstrap_jit(
112 backend: Arc<dyn VyreBackend>,
113 slot_count: u32,
114 workgroup_size_x: u32,
115 payload_processor: &[vyre_foundation::ir::Node],
116 ) -> Result<Self, PipelineError> {
117 validate_bootstrap_geometry(slot_count, workgroup_size_x)?;
118 let program = build_program_jit_slots(workgroup_size_x, slot_count, payload_processor);
119 Self::compile_bootstrap(backend, slot_count, workgroup_size_x, program)
120 }
121
122 fn compile_bootstrap(
123 backend: Arc<dyn VyreBackend>,
124 slot_count: u32,
125 workgroup_size_x: u32,
126 program: Program,
127 ) -> Result<Self, PipelineError> {
128 Self::compile_bootstrap_shared(backend, slot_count, workgroup_size_x, Arc::new(program))
129 }
130
131 fn compile_bootstrap_shared(
132 backend: Arc<dyn VyreBackend>,
133 slot_count: u32,
134 workgroup_size_x: u32,
135 program: Arc<Program>,
136 ) -> Result<Self, PipelineError> {
137 validate_bootstrap_geometry(slot_count, workgroup_size_x)?;
138 let config = DispatchConfig::default();
139 let pipeline = vyre_driver::pipeline::compile_shared(
140 Arc::clone(&backend),
141 Arc::clone(&program),
142 &config,
143 )?;
144 let pipeline_id = pipeline.id().to_string();
145 let has_grid_sync = vyre_driver::grid_sync::contains_grid_sync(&program);
146 let empty_io_queue_bytes =
147 Arc::<[u8]>::from(io::try_encode_empty_io_queue(io::IO_SLOT_COUNT)?.into_boxed_slice());
148 Ok(Self {
149 backend,
150 pipeline: ArcSwap::from(Arc::new(PipelineSlot { inner: pipeline })),
151 pipeline_id,
152 program,
153 has_grid_sync,
154 empty_io_queue_bytes,
155 slot_count,
156 workgroup_size_x,
157 recovery_policy: MegakernelRecoveryPolicy::default(),
158 })
159 }
160
161 pub fn dispatch(
168 &self,
169 control_bytes: Vec<u8>,
170 ring_bytes: Vec<u8>,
171 debug_log_bytes: Vec<u8>,
172 ) -> Result<Vec<Vec<u8>>, PipelineError> {
173 self.dispatch_borrowed(&control_bytes, &ring_bytes, &debug_log_bytes)
174 }
175
176 pub fn dispatch_borrowed(
183 &self,
184 control_bytes: &[u8],
185 ring_bytes: &[u8],
186 debug_log_bytes: &[u8],
187 ) -> Result<Vec<Vec<u8>>, PipelineError> {
188 Ok(self
189 .dispatch_borrowed_observed(control_bytes, ring_bytes, debug_log_bytes)?
190 .buffers)
191 }
192
193 pub fn dispatch_observed(
199 &self,
200 control_bytes: Vec<u8>,
201 ring_bytes: Vec<u8>,
202 debug_log_bytes: Vec<u8>,
203 ) -> Result<MegakernelDispatchOutput, PipelineError> {
204 self.dispatch_with_io_queue_borrowed_observed(
205 &control_bytes,
206 &ring_bytes,
207 &debug_log_bytes,
208 &self.empty_io_queue_bytes,
209 )
210 }
211
212 pub fn dispatch_borrowed_observed(
219 &self,
220 control_bytes: &[u8],
221 ring_bytes: &[u8],
222 debug_log_bytes: &[u8],
223 ) -> Result<MegakernelDispatchOutput, PipelineError> {
224 self.dispatch_with_io_queue_borrowed_observed(
225 control_bytes,
226 ring_bytes,
227 debug_log_bytes,
228 &self.empty_io_queue_bytes,
229 )
230 }
231
232 pub fn dispatch_with_io_queue(
239 &self,
240 control_bytes: Vec<u8>,
241 ring_bytes: Vec<u8>,
242 debug_log_bytes: Vec<u8>,
243 io_queue_bytes: Vec<u8>,
244 ) -> Result<Vec<Vec<u8>>, PipelineError> {
245 self.dispatch_with_io_queue_borrowed(
246 &control_bytes,
247 &ring_bytes,
248 &debug_log_bytes,
249 &io_queue_bytes,
250 )
251 }
252
253 pub fn dispatch_with_io_queue_borrowed(
259 &self,
260 control_bytes: &[u8],
261 ring_bytes: &[u8],
262 debug_log_bytes: &[u8],
263 io_queue_bytes: &[u8],
264 ) -> Result<Vec<Vec<u8>>, PipelineError> {
265 Ok(self
266 .dispatch_with_io_queue_borrowed_observed(
267 control_bytes,
268 ring_bytes,
269 debug_log_bytes,
270 io_queue_bytes,
271 )?
272 .buffers)
273 }
274
275 pub fn dispatch_with_io_queue_observed(
281 &self,
282 control_bytes: Vec<u8>,
283 ring_bytes: Vec<u8>,
284 debug_log_bytes: Vec<u8>,
285 io_queue_bytes: Vec<u8>,
286 ) -> Result<MegakernelDispatchOutput, PipelineError> {
287 self.dispatch_with_io_queue_borrowed_observed(
288 &control_bytes,
289 &ring_bytes,
290 &debug_log_bytes,
291 &io_queue_bytes,
292 )
293 }
294
295 pub fn dispatch_with_io_queue_borrowed_observed(
302 &self,
303 control_bytes: &[u8],
304 ring_bytes: &[u8],
305 debug_log_bytes: &[u8],
306 io_queue_bytes: &[u8],
307 ) -> Result<MegakernelDispatchOutput, PipelineError> {
308 let mut buffers = Vec::new();
309 reserve_output_shell(
310 &mut buffers,
311 MegakernelResidentHandles::ABI_RESOURCE_COUNT,
312 "borrowed megakernel output shell",
313 )?;
314 let stats = self.dispatch_with_io_queue_borrowed_into(
315 control_bytes,
316 ring_bytes,
317 debug_log_bytes,
318 io_queue_bytes,
319 &mut buffers,
320 )?;
321 Ok(MegakernelDispatchOutput { buffers, stats })
322 }
323
324 pub fn dispatch_with_io_queue_borrowed_into(
331 &self,
332 control_bytes: &[u8],
333 ring_bytes: &[u8],
334 debug_log_bytes: &[u8],
335 io_queue_bytes: &[u8],
336 outputs: &mut OutputBuffers,
337 ) -> Result<MegakernelDispatchStats, PipelineError> {
338 validate_control_bytes(control_bytes)?;
339 validate_debug_log_bytes(debug_log_bytes)?;
340 io::validate_io_queue_bytes(io_queue_bytes)?;
341 self.validate_ring_bytes(ring_bytes)?;
342
343 let input_bytes = total_len([control_bytes, ring_bytes, debug_log_bytes, io_queue_bytes])?;
344 let inputs = [control_bytes, ring_bytes, debug_log_bytes, io_queue_bytes];
345 let config = self.launch_geometry().dispatch_config(None);
346 let started = Instant::now();
347 let mut recovered = false;
348 match self.dispatch_once_into(&inputs, &config, outputs) {
349 Ok(()) => {}
350 Err(error) if self.recovery_policy.allows_retry(&error) => {
351 self.recover_after_device_loss()?;
352 recovered = true;
353 self.dispatch_once_into(&inputs, &config, outputs)?
354 }
355 Err(error) => return Err(error.into()),
356 }
357 let latency_ns = nanos_u64(started.elapsed().as_nanos())?;
358 let output_bytes = output_bytes(outputs)?;
359 let readback_bytes = output_bytes;
360 let bytes_moved = checked_add_u64(input_bytes, readback_bytes, "megakernel bytes moved")?;
361 let device_allocation_bytes = checked_add_u64(
362 input_bytes,
363 output_bytes,
364 "megakernel host-visible device allocation bytes",
365 )?;
366 let output_buffers = count_u32(outputs.len(), "megakernel output buffer count")?;
367 let device_allocation_events =
368 checked_add_u32(4, output_buffers, "megakernel allocation event count")?;
369 Ok(MegakernelDispatchStats {
370 input_bytes,
371 output_bytes,
372 readback_bytes,
373 bytes_moved,
374 device_allocation_bytes,
375 device_allocation_events,
376 latency_ns,
377 output_buffers,
378 resident_resource_rows: 0,
379 resident_resource_handles: 0,
380 kernel_launches: if recovered { 2 } else { 1 },
381 sync_points: 1,
382 recovered_after_device_loss: recovered,
383 })
384 }
385
386 pub fn recover_after_device_loss(&self) -> Result<MegakernelRecoveryDecision, PipelineError> {
395 let config = self.launch_geometry().dispatch_config(None);
396 let rebuilt = recover_compiled_pipeline(&self.backend, Arc::clone(&self.program), &config)?;
397 self.pipeline
398 .store(Arc::new(PipelineSlot { inner: rebuilt }));
399 Ok(MegakernelRecoveryDecision::RecompiledPipeline)
400 }
401
402 #[must_use]
404 pub fn pipeline_id(&self) -> &str {
405 &self.pipeline_id
406 }
407
408 #[must_use]
410 pub const fn slot_count(&self) -> u32 {
411 self.slot_count
412 }
413
414 #[must_use]
416 pub const fn workgroup_size_x(&self) -> u32 {
417 self.workgroup_size_x
418 }
419
420 #[must_use]
422 pub fn worker_groups(&self) -> u32 {
423 self.slot_count / self.workgroup_size_x
424 }
425
426 pub(super) fn validate_ring_bytes(&self, ring_bytes: &[u8]) -> Result<(), PipelineError> {
427 let expected_ring_bytes = protocol::ring_byte_len(self.slot_count).ok_or_else(|| {
428 PipelineError::Backend(
429 "megakernel ring byte length overflowed usize. Fix: split the ring into smaller dispatch shards."
430 .to_string(),
431 )
432 })?;
433 if ring_bytes.len() != expected_ring_bytes {
434 return Err(PipelineError::Backend(format!(
435 "megakernel ring buffer has {} bytes, expected {expected_ring_bytes} for {} slots. Fix: build ring bytes with Megakernel::encode_empty_ring(slot_count) for this handle.",
436 ring_bytes.len(),
437 self.slot_count
438 )));
439 }
440 Ok(())
441 }
442
443 pub(super) fn launch_geometry(&self) -> MegakernelLaunchGeometry {
444 MegakernelLaunchGeometry {
445 workgroup_size_x: self.workgroup_size_x,
446 slot_count: self.slot_count,
447 dispatch_grid: [self.slot_count / self.workgroup_size_x, 1, 1],
448 }
449 }
450
451 fn dispatch_once_into(
452 &self,
453 inputs: &[&[u8]; 4],
454 config: &DispatchConfig,
455 outputs: &mut OutputBuffers,
456 ) -> Result<(), vyre_driver::BackendError> {
457 if self.has_grid_sync && !self.backend.supports_grid_sync() {
458 return vyre_driver::grid_sync::dispatch_with_grid_sync_split_into(
459 self.backend.as_ref(),
460 &self.program,
461 inputs,
462 config,
463 outputs,
464 );
465 }
466 let pipeline = self.pipeline.load();
467 pipeline
468 .inner
469 .dispatch_borrowed_into(inputs, config, outputs)
470 }
471
472 fn dispatch_persistent_handles_once_into(
473 &self,
474 inputs: &[Resource; 4],
475 config: &DispatchConfig,
476 outputs: &mut OutputBuffers,
477 ) -> Result<(), vyre_driver::BackendError> {
478 let pipeline = self.pipeline.load();
479 pipeline
480 .inner
481 .dispatch_persistent_handles_into(inputs, config, outputs)
482 }
483
484 fn dispatch_persistent_handle_rows_once_into(
485 &self,
486 rows: &[[Resource; 4]],
487 config: &DispatchConfig,
488 outputs: &mut Vec<OutputBuffers>,
489 ) -> Result<(), vyre_driver::BackendError> {
490 let pipeline = self.pipeline.load();
491 pipeline
492 .inner
493 .dispatch_persistent_handle_rows_into(rows, config, outputs)
494 }
495}
496
497impl MegakernelRecoveryPolicy {
498 fn allows_retry(self, error: &vyre_driver::BackendError) -> bool {
499 self.retry_device_loss_once && backend_error_indicates_device_loss(error)
500 }
501}
502
503fn validate_bootstrap_geometry(
504 slot_count: u32,
505 workgroup_size_x: u32,
506) -> Result<(), PipelineError> {
507 if slot_count == 0 || workgroup_size_x == 0 || slot_count % workgroup_size_x != 0 {
508 return Err(PipelineError::QueueFull {
509 queue: "submission",
510 fix: "slot_count must be a non-zero multiple of workgroup_size_x",
511 });
512 }
513 Ok(())
514}
515
516pub(super) fn total_len<const N: usize>(buffers: [&[u8]; N]) -> Result<u64, PipelineError> {
517 let mut total = 0u64;
518 for buffer in buffers {
519 total = checked_add_u64(
520 total,
521 usize_to_u64(buffer.len(), "megakernel input buffer length")?,
522 "megakernel input byte total",
523 )?;
524 }
525 Ok(total)
526}
527
528pub(super) fn output_bytes(outputs: &[Vec<u8>]) -> Result<u64, PipelineError> {
529 let mut total = 0u64;
530 for output in outputs {
531 total = checked_add_u64(
532 total,
533 usize_to_u64(output.len(), "megakernel output buffer length")?,
534 "megakernel output byte total",
535 )?;
536 }
537 Ok(total)
538}
539
540pub(super) fn nested_output_bytes(outputs: &[Vec<Vec<u8>>]) -> Result<u64, PipelineError> {
541 let mut total = 0u64;
542 for row in outputs {
543 total = checked_add_u64(
544 total,
545 output_bytes(row)?,
546 "megakernel nested output byte total",
547 )?;
548 }
549 Ok(total)
550}
551
552pub(super) fn output_count_u32(outputs: &[Vec<u8>]) -> Result<u32, PipelineError> {
553 count_u32(outputs.len(), "megakernel output buffer count")
554}
555
556pub(super) fn nested_output_count_u32(outputs: &[Vec<Vec<u8>>]) -> Result<u32, PipelineError> {
557 let mut total = 0usize;
558 for row in outputs {
559 total = total.checked_add(row.len()).ok_or_else(|| {
560 PipelineError::Backend(
561 "megakernel nested output buffer count overflowed usize. Fix: split resident rows before dispatch.".to_string(),
562 )
563 })?;
564 }
565 count_u32(total, "megakernel nested output buffer count")
566}
567
568pub(super) fn resident_row_count_u32(rows: usize) -> Result<u32, PipelineError> {
569 count_u32(rows, "megakernel resident resource row count")
570}
571
572pub(super) fn resident_handle_count_u32(rows: usize) -> Result<u32, PipelineError> {
573 let handles = rows
574 .checked_mul(MegakernelResidentHandles::ABI_RESOURCE_COUNT)
575 .ok_or_else(|| {
576 PipelineError::Backend(
577 "megakernel resident resource handle count overflowed usize. Fix: split resident rows before dispatch.".to_string(),
578 )
579 })?;
580 count_u32(handles, "megakernel resident resource handle count")
581}
582
/// Pre-sizes the output vector `out` by delegating to `reserve_vec_capacity`.
///
/// `capacity` and `label` are forwarded unchanged. Presumably `capacity` is the
/// element count to reserve and `label` identifies the allocation in any error
/// message — TODO confirm against `staging_reserve::reserve_vec_capacity`.
///
/// # Errors
///
/// Propagates whatever error `reserve_vec_capacity` returns.
pub(super) fn reserve_output_shell<T>(
    out: &mut Vec<T>,
    capacity: usize,
    label: &'static str,
) -> Result<(), PipelineError> {
    reserve_vec_capacity(out, capacity, label)
}
590
591pub(super) fn nanos_u64(nanos: u128) -> Result<u64, PipelineError> {
592 u64::try_from(nanos).map_err(|source| {
593 PipelineError::Backend(format!(
594 "megakernel latency cannot fit u64 nanoseconds: {source}. Fix: timeout or split the dispatch before telemetry overflows."
595 ))
596 })
597}
598
599fn usize_to_u64(value: usize, label: &str) -> Result<u64, PipelineError> {
600 u64::try_from(value).map_err(|source| {
601 PipelineError::Backend(format!(
602 "{label} cannot fit u64: {source}. Fix: split the megakernel dispatch before telemetry/accounting."
603 ))
604 })
605}
606
607fn count_u32(value: usize, label: &str) -> Result<u32, PipelineError> {
608 u32::try_from(value).map_err(|source| {
609 PipelineError::Backend(format!(
610 "{label} cannot fit u32: {source}. Fix: split the megakernel dispatch before telemetry/accounting."
611 ))
612 })
613}
614
615fn checked_add_u64(left: u64, right: u64, label: &str) -> Result<u64, PipelineError> {
616 left.checked_add(right).ok_or_else(|| {
617 PipelineError::Backend(format!(
618 "{label} overflowed u64. Fix: split the megakernel dispatch before telemetry/accounting."
619 ))
620 })
621}
622
623fn checked_add_u32(left: u32, right: u32, label: &str) -> Result<u32, PipelineError> {
624 left.checked_add(right).ok_or_else(|| {
625 PipelineError::Backend(format!(
626 "{label} overflowed u32. Fix: split the megakernel dispatch before telemetry/accounting."
627 ))
628 })
629}
630
631#[cfg(test)]
632mod tests;