1use std::sync::Arc;
8
9use vyre_foundation::ir::{BufferDecl, DataType, Expr, Node, Program};
10
11use super::handlers::{claimed_slot_bindings, claimed_slot_body, load_miss_body, OpcodeHandler};
12use super::io::{
13 io_word, IO_DESTINATION_CAPABILITY_TABLE, IO_QUEUE_DMA_TAG, IO_SLOT_COUNT, IO_SLOT_WORDS,
14 IO_SOURCE_CAPABILITY_TABLE,
15};
16use super::ir_util::atomic_load_relaxed;
17use super::protocol::*;
18use super::workspace_adapter::MegakernelWorkspaceAdapter;
19mod cache;
20mod jit;
21mod priority;
22pub use jit::{build_program_jit, build_program_jit_slots, persistent_body_jit};
23pub use priority::{
24 build_program_priority, build_program_priority_slots, persistent_body_priority,
25 persistent_body_priority_slots,
26};
27
28#[must_use]
30pub fn build_program() -> Program {
31 build_program_sharded(256, &[])
32}
33
34#[must_use]
42pub fn build_program_sharded(workgroup_size_x: u32, opcodes: &[OpcodeHandler]) -> Program {
43 build_program_sharded_slots(workgroup_size_x, workgroup_size_x.max(1), opcodes)
44}
45
46#[must_use]
52pub fn build_program_sharded_slots(
53 workgroup_size_x: u32,
54 slot_count: u32,
55 opcodes: &[OpcodeHandler],
56) -> Program {
57 build_program_sharded_slots_with_io(workgroup_size_x, slot_count, opcodes, false)
58}
59
60#[must_use]
66pub fn build_program_sharded_slots_shared(
67 workgroup_size_x: u32,
68 slot_count: u32,
69 opcodes: &[OpcodeHandler],
70) -> Arc<Program> {
71 if opcodes.is_empty() {
72 return cache::cached_empty_sharded_program_shared(workgroup_size_x, slot_count, false);
73 }
74 Arc::new(build_program_sharded_slots(
75 workgroup_size_x,
76 slot_count,
77 opcodes,
78 ))
79}
80
81#[must_use]
83pub fn build_program_sharded_with_workspace_adapter(
84 workgroup_size_x: u32,
85 slot_count: u32,
86 opcodes: &[OpcodeHandler],
87 adapter: &impl MegakernelWorkspaceAdapter,
88) -> Program {
89 wrap_persistent_megakernel_program_with_buffers(
90 default_buffers_with_workspace_adapter(slot_count, adapter),
91 workgroup_size_x,
92 persistent_body_with_workspace_adapter(workgroup_size_x, opcodes, adapter),
93 )
94}
95
96#[must_use]
103pub fn build_program_sharded_once_slots(
104 workgroup_size_x: u32,
105 slot_count: u32,
106 opcodes: &[OpcodeHandler],
107) -> Program {
108 if opcodes.is_empty() {
109 return cache::cached_empty_sharded_once_program(workgroup_size_x, slot_count);
110 }
111 wrap_megakernel_program(
112 workgroup_size_x,
113 slot_count,
114 persistent_body_with_io(workgroup_size_x, opcodes, false),
115 )
116}
117
118#[must_use]
121pub fn build_program_sharded_once_slots_shared(
122 workgroup_size_x: u32,
123 slot_count: u32,
124 opcodes: &[OpcodeHandler],
125) -> Arc<Program> {
126 if opcodes.is_empty() {
127 return cache::cached_empty_sharded_once_program_shared(workgroup_size_x, slot_count);
128 }
129 Arc::new(build_program_sharded_once_slots(
130 workgroup_size_x,
131 slot_count,
132 opcodes,
133 ))
134}
135
136#[must_use]
144pub fn build_program_sharded_once_slots_control_report_shared(
145 workgroup_size_x: u32,
146 slot_count: u32,
147 opcodes: &[OpcodeHandler],
148) -> Arc<Program> {
149 if opcodes.is_empty() {
150 return cache::cached_empty_sharded_once_control_report_program_shared(
151 workgroup_size_x,
152 slot_count,
153 );
154 }
155 let mut buffers = default_buffers(slot_count);
156 for buffer in buffers.iter_mut().skip(1) {
157 buffer.output_byte_range = Some(0..0);
158 }
159 Arc::new(optimize_megakernel_program(Program::wrapped(
160 buffers,
161 [workgroup_size_x, 1, 1],
162 persistent_body_with_io(workgroup_size_x, opcodes, false),
163 )))
164}
165
166#[must_use]
172pub fn build_program_sharded_no_io(workgroup_size_x: u32, opcodes: &[OpcodeHandler]) -> Program {
173 build_program_sharded_slots(workgroup_size_x, workgroup_size_x.max(1), opcodes)
174}
175
176#[must_use]
181pub fn build_program_sharded_with_io_polling(
182 workgroup_size_x: u32,
183 opcodes: &[OpcodeHandler],
184) -> Program {
185 build_program_sharded_slots_with_io(workgroup_size_x, workgroup_size_x.max(1), opcodes, true)
186}
187
/// Infallible wrapper around [`try_build_program_with_self_loading_miss_handler`].
///
/// # Panics
/// Panics with the builder's error message if the handler set cannot be
/// extended with the `LOAD_MISS` handler.
#[must_use]
#[cfg(any(test, feature = "legacy-infallible"))]
pub fn build_program_with_self_loading_miss_handler(
    workgroup_size_x: u32,
    slot_count: u32,
    opcodes: &[OpcodeHandler],
) -> Program {
    try_build_program_with_self_loading_miss_handler(workgroup_size_x, slot_count, opcodes)
        .unwrap_or_else(|error| panic!("{error}"))
}
207
208pub fn try_build_program_with_self_loading_miss_handler(
210 workgroup_size_x: u32,
211 slot_count: u32,
212 opcodes: &[OpcodeHandler],
213) -> Result<Program, String> {
214 let mut extended = Vec::new();
215 let extended_len = opcodes.len().checked_add(1).ok_or_else(|| {
216 "megakernel self-loading opcode extension count overflowed usize. Fix: split opcode handler sets before building the megakernel."
217 .to_string()
218 })?;
219 vyre_foundation::allocation::try_reserve_vec_to_capacity(&mut extended, extended_len).map_err(|error| {
220 format!(
221 "megakernel self-loading opcode extension allocation failed: {error}. Fix: split opcode handler sets before building the megakernel."
222 )
223 })?;
224 extended.extend_from_slice(opcodes);
225 extended.push(OpcodeHandler {
226 opcode: super::protocol::opcode::LOAD_MISS,
227 body: load_miss_body(),
228 });
229 Ok(wrap_persistent_megakernel_program(
230 workgroup_size_x,
231 slot_count,
232 persistent_body_with_io(workgroup_size_x, &extended, false),
233 ))
234}
235
236fn build_program_sharded_slots_with_io(
237 workgroup_size_x: u32,
238 slot_count: u32,
239 opcodes: &[OpcodeHandler],
240 include_io_polling: bool,
241) -> Program {
242 if opcodes.is_empty() {
243 return cache::cached_empty_sharded_program(
244 workgroup_size_x,
245 slot_count,
246 include_io_polling,
247 );
248 }
249 wrap_persistent_megakernel_program(
250 workgroup_size_x,
251 slot_count,
252 persistent_body_with_io(workgroup_size_x, opcodes, include_io_polling),
253 )
254}
255
256fn wrap_persistent_megakernel_program(
257 workgroup_size_x: u32,
258 slot_count: u32,
259 body: Vec<Node>,
260) -> Program {
261 wrap_megakernel_program(workgroup_size_x, slot_count, vec![Node::forever(body)])
262}
263
264fn wrap_persistent_megakernel_program_with_buffers(
265 buffers: Vec<BufferDecl>,
266 workgroup_size_x: u32,
267 body: Vec<Node>,
268) -> Program {
269 optimize_megakernel_program(Program::wrapped(
270 buffers,
271 [workgroup_size_x, 1, 1],
272 vec![Node::forever(body)],
273 ))
274}
275
276fn wrap_megakernel_program(workgroup_size_x: u32, slot_count: u32, body: Vec<Node>) -> Program {
277 optimize_megakernel_program(Program::wrapped(
278 default_buffers(slot_count),
279 [workgroup_size_x, 1, 1],
280 body,
281 ))
282}
283
284fn optimize_megakernel_program(program: Program) -> Program {
285 let fallback = program.clone();
286 let program = match super::planner::try_elide_value_flow_barriers(program) {
287 Ok((program, _)) => program,
288 Err(_) => fallback,
289 };
290 vyre_foundation::optimizer::pre_lowering::optimize(program)
291}
292
293fn default_buffers(slot_count: u32) -> Vec<BufferDecl> {
307 let ring_slots = slot_count.max(1);
308 let control = BufferDecl::read_write("control", 0, DataType::U32).with_count(CONTROL_MIN_WORDS);
309 let ring_buffer = BufferDecl::read_write("ring_buffer", 1, DataType::U32)
310 .with_count(ring_slots.saturating_mul(SLOT_WORDS));
311 let debug_log =
312 BufferDecl::read_write("debug_log", 2, DataType::U32).with_count(debug::BUFFER_WORDS);
313 let io_queue = BufferDecl::read_write("io_queue", 3, DataType::U32).with_count(64 * 8);
314 vec![control, ring_buffer, debug_log, io_queue]
315}
316
317fn default_buffers_with_workspace_adapter(
318 slot_count: u32,
319 adapter: &impl MegakernelWorkspaceAdapter,
320) -> Vec<BufferDecl> {
321 let mut buffers = default_buffers(slot_count);
322 buffers.push(adapter.buffer_decl());
323 buffers
324}
325
326#[must_use]
329pub fn persistent_body(workgroup_size_x: u32, opcodes: &[OpcodeHandler]) -> Vec<Node> {
330 persistent_body_with_io(workgroup_size_x, opcodes, false)
331}
332
333pub fn try_persistent_body(
335 workgroup_size_x: u32,
336 opcodes: &[OpcodeHandler],
337) -> Result<Vec<Node>, String> {
338 try_persistent_body_with_io(workgroup_size_x, opcodes, false)
339}
340
341fn persistent_body_with_io(
342 workgroup_size_x: u32,
343 opcodes: &[OpcodeHandler],
344 include_io_polling: bool,
345) -> Vec<Node> {
346 let mut body = persistent_lane_prologue(workgroup_size_x);
347 let additional_nodes = if include_io_polling { 3 } else { 2 };
348 if let Some(body_capacity) = body.len().checked_add(additional_nodes) {
349 let _ = vyre_foundation::allocation::try_reserve_vec_to_capacity(&mut body, body_capacity);
350 }
351 body.push(direct_slot_base_binding());
352 body.push(Node::Block(execute_slot_body(opcodes)));
353 if include_io_polling {
354 body.push(Node::Block(process_io_requests()));
355 }
356 body
357}
358
359fn try_persistent_body_with_io(
360 workgroup_size_x: u32,
361 opcodes: &[OpcodeHandler],
362 include_io_polling: bool,
363) -> Result<Vec<Node>, String> {
364 let mut body = persistent_lane_prologue(workgroup_size_x);
365 let additional_nodes = if include_io_polling { 3 } else { 2 };
366 let body_capacity = body.len().checked_add(additional_nodes).ok_or_else(|| {
367 "megakernel persistent body node reservation overflowed usize. Fix: reduce fused IO/body staging before building the megakernel."
368 .to_string()
369 })?;
370 vyre_foundation::allocation::try_reserve_vec_to_capacity(&mut body, body_capacity).map_err(|error| {
371 format!(
372 "megakernel persistent body node reservation failed: {error}. Fix: reduce fused IO/body staging before building the megakernel."
373 )
374 })?;
375 body.push(direct_slot_base_binding());
376 body.push(Node::Block(execute_slot_body(opcodes)));
377 if include_io_polling {
378 body.push(Node::Block(process_io_requests()));
379 }
380 Ok(body)
381}
382
383fn persistent_lane_prologue(workgroup_size_x: u32) -> Vec<Node> {
384 vec![
385 Node::let_bind(
386 "shutdown_flag",
387 atomic_load_relaxed("control", Expr::u32(control::SHUTDOWN)),
388 ),
389 Node::if_then(
390 Expr::ne(Expr::var("shutdown_flag"), Expr::u32(0)),
391 vec![Node::Return],
392 ),
393 Node::let_bind("lane_id", lane_id_expr(workgroup_size_x)),
394 ]
395}
396
397fn direct_slot_base_binding() -> Node {
398 Node::let_bind(
399 "slot_base",
400 Expr::mul(Expr::var("lane_id"), Expr::u32(SLOT_WORDS)),
401 )
402}
403
404fn slot_tenant_id_load() -> Expr {
405 Expr::load(
406 "ring_buffer",
407 Expr::add(Expr::var("slot_base"), Expr::u32(TENANT_WORD)),
408 )
409}
410
411fn tenant_authorized_body(tenant_id: Expr, authorized_body: Vec<Node>) -> Vec<Node> {
412 vec![
413 Node::let_bind("tenant_id", tenant_id),
414 Node::let_bind(
415 "tenant_base",
416 atomic_load_relaxed("control", Expr::u32(control::TENANT_BASE)),
417 ),
418 Node::let_bind(
419 "tenant_mask",
420 atomic_load_relaxed(
421 "control",
422 Expr::add(Expr::var("tenant_base"), Expr::var("tenant_id")),
423 ),
424 ),
425 Node::if_then(
426 Expr::ne(Expr::var("tenant_mask"), Expr::u32(0)),
427 authorized_body,
428 ),
429 ]
430}
431
432fn lane_id_expr(workgroup_size_x: u32) -> Expr {
433 Expr::add(
434 Expr::mul(Expr::workgroup_x(), Expr::u32(workgroup_size_x)),
435 Expr::local_x(),
436 )
437}
438
439fn persistent_body_with_workspace_adapter(
440 workgroup_size_x: u32,
441 opcodes: &[OpcodeHandler],
442 adapter: &impl MegakernelWorkspaceAdapter,
443) -> Vec<Node> {
444 let mut body = adapter.bootstrap_nodes();
445 body.extend(adapter.guard_nodes());
446 body.extend(adapter.dispatch_nodes());
447 body.extend(persistent_body_with_io(workgroup_size_x, opcodes, false));
448 body
449}
450
451fn process_io_requests() -> Vec<Node> {
452 let nodes = vec![Node::loop_for(
453 "io_idx",
454 Expr::u32(0),
455 Expr::u32(IO_SLOT_COUNT),
456 vec![
457 Node::let_bind(
458 "io_base",
459 Expr::mul(Expr::var("io_idx"), Expr::u32(IO_SLOT_WORDS)),
460 ),
461 Node::let_bind(
462 "io_status_idx",
463 Expr::add(Expr::var("io_base"), Expr::u32(io_word::STATUS)),
464 ),
465 Node::let_bind(
467 "prev_io_status",
468 Expr::atomic_compare_exchange(
469 "io_queue",
470 Expr::var("io_status_idx"),
471 Expr::u32(slot::PUBLISHED),
472 Expr::u32(slot::CLAIMED),
473 ),
474 ),
475 Node::if_then(
476 Expr::eq(Expr::var("prev_io_status"), Expr::u32(slot::PUBLISHED)),
477 vec![
478 Node::let_bind(
479 "io_src_handle",
480 Expr::load(
481 "io_queue",
482 Expr::add(Expr::var("io_base"), Expr::u32(io_word::SRC_HANDLE)),
483 ),
484 ),
485 Node::let_bind(
486 "io_dst_handle",
487 Expr::load(
488 "io_queue",
489 Expr::add(Expr::var("io_base"), Expr::u32(io_word::DST_HANDLE)),
490 ),
491 ),
492 Node::AsyncLoad {
493 source: IO_SOURCE_CAPABILITY_TABLE.into(),
494 destination: IO_DESTINATION_CAPABILITY_TABLE.into(),
495 offset: Box::new(Expr::load(
496 "io_queue",
497 Expr::add(Expr::var("io_base"), Expr::u32(io_word::OFFSET_LO)),
498 )),
499 size: Box::new(Expr::load(
500 "io_queue",
501 Expr::add(Expr::var("io_base"), Expr::u32(io_word::BYTE_COUNT)),
502 )),
503 tag: IO_QUEUE_DMA_TAG.into(),
504 },
505 Node::store(
507 "io_queue",
508 Expr::var("io_status_idx"),
509 Expr::u32(slot::DONE),
510 ),
511 ],
512 ),
513 ],
514 )];
515
516 nodes
517}
518
519fn execute_slot_body(opcodes: &[OpcodeHandler]) -> Vec<Node> {
520 vec![
521 Node::let_bind(
522 "status_index",
523 Expr::add(Expr::var("slot_base"), Expr::u32(STATUS_WORD)),
524 ),
525 Node::let_bind(
526 "observed_status",
527 atomic_load_relaxed("ring_buffer", Expr::var("status_index")),
528 ),
529 Node::if_then(
530 Expr::eq(Expr::var("observed_status"), Expr::u32(slot::PUBLISHED)),
531 tenant_authorized_claim_body(slot_tenant_id_load(), claimed_slot_body(opcodes)),
532 ),
533 ]
534}
535
536fn tenant_authorized_claim_body(tenant_id: Expr, claimed_body: Vec<Node>) -> Vec<Node> {
537 tenant_authorized_body(
538 tenant_id,
539 vec![
540 Node::let_bind(
544 "prev_status",
545 Expr::atomic_compare_exchange(
546 "ring_buffer",
547 Expr::var("status_index"),
548 Expr::u32(slot::PUBLISHED),
549 Expr::u32(slot::CLAIMED),
550 ),
551 ),
552 Node::if_then(
553 Expr::eq(Expr::var("prev_status"), Expr::u32(slot::PUBLISHED)),
554 claimed_body,
555 ),
556 ],
557 )
558}
559
560fn execute_already_claimed_slot_body(tenant_id: Expr, claimed_body: Vec<Node>) -> Vec<Node> {
561 let mut body = vec![Node::let_bind(
562 "status_index",
563 Expr::add(Expr::var("slot_base"), Expr::u32(STATUS_WORD)),
564 )];
565 body.extend(tenant_authorized_body(tenant_id, claimed_body));
566 body
567}
568
569#[cfg(test)]
570mod tests;