Skip to main content

vyre_runtime/megakernel/builder/
jit.rs

1use super::persistent_lane_prologue;
2use super::{
3    claimed_slot_bindings, direct_slot_base_binding, process_io_requests, slot_tenant_id_load,
4    tenant_authorized_claim_body, wrap_persistent_megakernel_program,
5};
6use crate::megakernel::ir_util::atomic_load_relaxed;
7use crate::megakernel::protocol::{control, slot, STATUS_WORD};
8use vyre_foundation::ir::{Expr, Node, Program};
9
10/// Build the JIT Megakernel IR where payload processor logic is fused into the body stream.
11#[must_use]
12pub fn build_program_jit(workgroup_size_x: u32, payload_processor: &[Node]) -> Program {
13    build_program_jit_slots(workgroup_size_x, workgroup_size_x.max(1), payload_processor)
14}
15
16/// Build the JIT megakernel IR for an explicit number of ring slots.
17#[must_use]
18pub fn build_program_jit_slots(
19    workgroup_size_x: u32,
20    slot_count: u32,
21    payload_processor: &[Node],
22) -> Program {
23    wrap_persistent_megakernel_program(
24        workgroup_size_x,
25        slot_count,
26        persistent_body_jit(workgroup_size_x, payload_processor),
27    )
28}
29
30fn execute_slot_body_jit(payload_processor: &[Node]) -> Vec<Node> {
31    vec![
32        Node::let_bind(
33            "status_index",
34            Expr::add(Expr::var("slot_base"), Expr::u32(STATUS_WORD)),
35        ),
36        Node::let_bind(
37            "observed_status",
38            atomic_load_relaxed("ring_buffer", Expr::var("status_index")),
39        ),
40        Node::if_then(
41            Expr::eq(Expr::var("observed_status"), Expr::u32(slot::PUBLISHED)),
42            tenant_authorized_claim_body(
43                slot_tenant_id_load(),
44                claimed_slot_body_jit(payload_processor),
45            ),
46        ),
47    ]
48}
49
50// ---- JIT variant ----
51
52/// The JIT body that runs once per iteration per lane.
53#[must_use]
54pub fn persistent_body_jit(workgroup_size_x: u32, payload_processor: &[Node]) -> Vec<Node> {
55    let mut body = persistent_lane_prologue(workgroup_size_x);
56    if let Some(body_capacity) = body.len().checked_add(3) {
57        let _ = vyre_foundation::allocation::try_reserve_vec_to_capacity(&mut body, body_capacity);
58    }
59    body.push(direct_slot_base_binding());
60    body.push(Node::Block(execute_slot_body_jit(payload_processor)));
61    body.push(Node::Block(process_io_requests()));
62    body
63}
64
65/// Fallible JIT body builder with explicit staging-allocation reporting.
66pub(super) fn try_persistent_body_jit(
67    workgroup_size_x: u32,
68    payload_processor: &[Node],
69) -> Result<Vec<Node>, String> {
70    let mut body = persistent_lane_prologue(workgroup_size_x);
71    let body_capacity = body.len().checked_add(3).ok_or_else(|| {
72        "megakernel JIT body node reservation overflowed usize. Fix: reduce fused payload/body staging before building the JIT megakernel."
73            .to_string()
74    })?;
75    vyre_foundation::allocation::try_reserve_vec_to_capacity(&mut body, body_capacity).map_err(|error| {
76        format!(
77            "megakernel JIT body node reservation failed: {error}. Fix: reduce fused payload/body staging before building the JIT megakernel."
78        )
79    })?;
80    body.push(direct_slot_base_binding());
81    body.push(Node::Block(execute_slot_body_jit(payload_processor)));
82    body.push(Node::Block(process_io_requests()));
83    Ok(body)
84}
85
86fn claimed_slot_body_jit(payload_processor: &[Node]) -> Vec<Node> {
87    let mut nodes = claimed_slot_bindings();
88
89    // Wire the statically JIT-compiled rule/payload evaluation graph.
90    nodes.extend(payload_processor.iter().cloned());
91
92    nodes.push(Node::let_bind(
93        "done_prev",
94        Expr::atomic_add("control", Expr::u32(control::DONE_COUNT), Expr::u32(1)),
95    ));
96    nodes.push(Node::store(
97        "ring_buffer",
98        Expr::var("status_index"),
99        Expr::u32(slot::DONE),
100    ));
101    nodes
102}