Skip to main content

vyre_driver_cuda/
token_fact_graph_cuda_adapter.rs

1//! CUDA adapter for the unified resident token/fact graph.
2
3use crate::backend::accounting::{
4    checked_add_u64_count as checked_add, checked_mul_u64_count as checked_mul,
5    CudaArithmeticOverflow,
6};
7use crate::megakernel_scheduler::CudaMegakernelGraphShape;
8use vyre_self_substrate::device_resident_token_fact_graph::DeviceResidentTokenFactGraph;
9
10/// CUDA resident byte envelope for the unified compiler/dataflow graph.
11#[derive(Clone, Copy, Debug, Eq, PartialEq)]
12pub struct CudaTokenFactGraphLayout {
13    /// Scheduler-visible graph shape.
14    pub graph_shape: CudaMegakernelGraphShape,
15    /// Fixed bytes per resident node record.
16    pub node_record_bytes: u64,
17    /// Fixed bytes per resident edge record.
18    pub edge_record_bytes: u64,
19    /// Bytes for resident node records.
20    pub node_bytes: u64,
21    /// Bytes for resident edge records.
22    pub edge_bytes: u64,
23    /// Bytes for the shared token/fact payload slab.
24    pub payload_bytes: u64,
25    /// Total bytes that must remain device-resident for the layout.
26    pub resident_bytes: u64,
27}
28
/// Failures reported while adapting a token/fact graph to the CUDA resident layout.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CudaTokenFactGraphLayoutError {
    /// A per-record byte width was zero; record widths must be explicit,
    /// non-zero ABI values.
    ZeroRecordWidth {
        /// Name of the width argument that was zero.
        field: &'static str,
    },
    /// A count or byte total did not fit in `u64`.
    ByteCountOverflow {
        /// Name of the quantity whose computation overflowed.
        field: &'static str,
    },
}
43
44impl std::fmt::Display for CudaTokenFactGraphLayoutError {
45    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46        match self {
47            Self::ZeroRecordWidth { field } => write!(
48                f,
49                "CUDA token/fact graph adapter received zero {field}. Fix: pass the concrete resident ABI record width."
50            ),
51            Self::ByteCountOverflow { field } => write!(
52                f,
53                "CUDA token/fact graph adapter overflowed while computing {field}. Fix: shard the token/fact graph before resident upload."
54            ),
55        }
56    }
57}
58
59impl std::error::Error for CudaTokenFactGraphLayoutError {}
60
61impl CudaArithmeticOverflow for CudaTokenFactGraphLayoutError {
62    fn arithmetic_overflow(field: &'static str) -> Self {
63        Self::ByteCountOverflow { field }
64    }
65}
66
67/// Convert the unified token/fact graph into CUDA scheduler shape and bytes.
68pub fn adapt_token_fact_graph_to_cuda_layout(
69    graph: &DeviceResidentTokenFactGraph,
70    node_record_bytes: u64,
71    edge_record_bytes: u64,
72) -> Result<CudaTokenFactGraphLayout, CudaTokenFactGraphLayoutError> {
73    if node_record_bytes == 0 {
74        return Err(CudaTokenFactGraphLayoutError::ZeroRecordWidth {
75            field: "node_record_bytes",
76        });
77    }
78    if edge_record_bytes == 0 {
79        return Err(CudaTokenFactGraphLayoutError::ZeroRecordWidth {
80            field: "edge_record_bytes",
81        });
82    }
83    let node_count = u64::try_from(graph.node_ids.len()).map_err(|_| {
84        CudaTokenFactGraphLayoutError::ByteCountOverflow {
85            field: "node count",
86        }
87    })?;
88    let edge_count = u64::try_from(graph.column_indices.len()).map_err(|_| {
89        CudaTokenFactGraphLayoutError::ByteCountOverflow {
90            field: "edge count",
91        }
92    })?;
93    let node_bytes = checked_mul(node_count, node_record_bytes, "node bytes")?;
94    let edge_bytes = checked_mul(edge_count, edge_record_bytes, "edge bytes")?;
95    let resident_without_payload = checked_add(node_bytes, edge_bytes, "node plus edge bytes")?;
96    let resident_bytes = checked_add(
97        resident_without_payload,
98        graph.payload_bytes,
99        "resident bytes",
100    )?;
101
102    Ok(CudaTokenFactGraphLayout {
103        graph_shape: CudaMegakernelGraphShape {
104            node_count,
105            edge_count,
106        },
107        node_record_bytes,
108        edge_record_bytes,
109        node_bytes,
110        edge_bytes,
111        payload_bytes: graph.payload_bytes,
112        resident_bytes,
113    })
114}
115
#[cfg(test)]
mod tests {
    use super::*;
    use crate::megakernel_scheduler::{plan_cuda_megakernel_memory_budget, CudaMegakernelTopology};
    use vyre_self_substrate::device_resident_token_fact_graph::{
        plan_device_resident_token_fact_graph, DeviceResidentTokenFactGraph, TokenFactEdge,
        TokenFactEdgeKind, TokenFactNode, TokenFactNodeKind,
    };

    #[test]
    fn token_fact_adapter_uses_shared_typed_cuda_arithmetic() {
        let source = include_str!("token_fact_graph_cuda_adapter.rs");

        assert!(source.contains("checked_add_u64_count as checked_add"));
        assert!(source.contains("checked_mul_u64_count as checked_mul"));
        assert!(source.contains("impl CudaArithmeticOverflow for CudaTokenFactGraphLayoutError"));
        // `concat!` keeps the forbidden helper signatures out of this file's own text.
        assert!(!source.contains(concat!("fn checked_", "mul(")));
        assert!(!source.contains(concat!("fn checked_", "add(")));
    }

    #[test]
    fn adapter_accounts_for_cuda_resident_token_fact_layout() {
        let graph = three_node_graph();

        let cuda = adapt_token_fact_graph_to_cuda_layout(&graph, 32, 16)
            .expect("Fix: token/fact graph should adapt to CUDA layout");

        assert_eq!(cuda.graph_shape.node_count, 3);
        assert_eq!(cuda.graph_shape.edge_count, 2);
        assert_eq!(cuda.node_bytes, 96);
        assert_eq!(cuda.edge_bytes, 32);
        // 96 node bytes + 32 edge bytes + 24 payload bytes.
        assert_eq!(cuda.resident_bytes, 152);
        let memory = plan_cuda_megakernel_memory_budget(
            CudaMegakernelTopology::SparseFrontier,
            cuda.graph_shape,
            cuda.node_record_bytes,
            cuda.edge_record_bytes,
            64,
            cuda.payload_bytes,
            16,
            512,
        )
        .expect("Fix: adapted token/fact graph should feed CUDA memory planning");
        assert_eq!(memory.graph_bytes, 128);
    }

    #[test]
    fn adapter_rejects_missing_abi_widths() {
        let graph = plan_device_resident_token_fact_graph(&[], &[], 0)
            .expect("Fix: empty graph still has a valid resident layout");

        assert_eq!(
            adapt_token_fact_graph_to_cuda_layout(&graph, 0, 8)
                .expect_err("zero node record width should fail"),
            CudaTokenFactGraphLayoutError::ZeroRecordWidth {
                field: "node_record_bytes",
            }
        );
        assert_eq!(
            adapt_token_fact_graph_to_cuda_layout(&graph, 8, 0)
                .expect_err("zero edge record width should fail"),
            CudaTokenFactGraphLayoutError::ZeroRecordWidth {
                field: "edge_record_bytes",
            }
        );
    }

    #[test]
    fn adapter_reports_byte_count_overflow() {
        let graph = three_node_graph();

        // Three node records of u64::MAX bytes each cannot be totalled in a u64.
        let err = adapt_token_fact_graph_to_cuda_layout(&graph, u64::MAX, 16)
            .expect_err("node byte total beyond u64 should fail");
        assert!(
            matches!(err, CudaTokenFactGraphLayoutError::ByteCountOverflow { .. }),
            "expected ByteCountOverflow, got {err:?}"
        );
        assert!(err.to_string().contains("Fix:"));
    }

    /// Token -> semantic -> fact chain: three 8-byte payload nodes over a
    /// 24-byte slab, joined by two edges.
    fn three_node_graph() -> DeviceResidentTokenFactGraph {
        plan_device_resident_token_fact_graph(
            &[
                node(1, TokenFactNodeKind::Token, 0, 8),
                node(2, TokenFactNodeKind::Semantic, 8, 8),
                node(3, TokenFactNodeKind::Fact, 16, 8),
            ],
            &[
                edge(1, 2, TokenFactEdgeKind::SemanticFact),
                edge(2, 3, TokenFactEdgeKind::FactDependency),
            ],
            24,
        )
        .expect("Fix: token/fact graph should pack")
    }

    fn node(
        id: u32,
        kind: TokenFactNodeKind,
        payload_offset: u64,
        payload_bytes: u64,
    ) -> TokenFactNode {
        TokenFactNode {
            id,
            kind,
            payload_offset,
            payload_bytes,
        }
    }

    fn edge(from: u32, to: u32, kind: TokenFactEdgeKind) -> TokenFactEdge {
        TokenFactEdge { from, to, kind }
    }
}
208
209    fn edge(from: u32, to: u32, kind: TokenFactEdgeKind) -> TokenFactEdge {
210        TokenFactEdge { from, to, kind }
211    }
212}