1use crate::backend::accounting::{
4 checked_add_u64_count as checked_add, checked_mul_u64_count as checked_mul,
5 CudaArithmeticOverflow,
6};
7use crate::megakernel_scheduler::CudaMegakernelGraphShape;
8use vyre_self_substrate::device_resident_token_fact_graph::DeviceResidentTokenFactGraph;
9
10#[derive(Clone, Copy, Debug, Eq, PartialEq)]
12pub struct CudaTokenFactGraphLayout {
13 pub graph_shape: CudaMegakernelGraphShape,
15 pub node_record_bytes: u64,
17 pub edge_record_bytes: u64,
19 pub node_bytes: u64,
21 pub edge_bytes: u64,
23 pub payload_bytes: u64,
25 pub resident_bytes: u64,
27}
28
/// Failures reported while adapting a token/fact graph to its CUDA layout.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CudaTokenFactGraphLayoutError {
    /// A record width argument was zero; `field` names the offending argument.
    ZeroRecordWidth { field: &'static str },
    /// A count or byte total did not fit in `u64`; `field` names the quantity
    /// that was being computed.
    ByteCountOverflow { field: &'static str },
}
43
44impl std::fmt::Display for CudaTokenFactGraphLayoutError {
45 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46 match self {
47 Self::ZeroRecordWidth { field } => write!(
48 f,
49 "CUDA token/fact graph adapter received zero {field}. Fix: pass the concrete resident ABI record width."
50 ),
51 Self::ByteCountOverflow { field } => write!(
52 f,
53 "CUDA token/fact graph adapter overflowed while computing {field}. Fix: shard the token/fact graph before resident upload."
54 ),
55 }
56 }
57}
58
59impl std::error::Error for CudaTokenFactGraphLayoutError {}
60
61impl CudaArithmeticOverflow for CudaTokenFactGraphLayoutError {
62 fn arithmetic_overflow(field: &'static str) -> Self {
63 Self::ByteCountOverflow { field }
64 }
65}
66
67pub fn adapt_token_fact_graph_to_cuda_layout(
69 graph: &DeviceResidentTokenFactGraph,
70 node_record_bytes: u64,
71 edge_record_bytes: u64,
72) -> Result<CudaTokenFactGraphLayout, CudaTokenFactGraphLayoutError> {
73 if node_record_bytes == 0 {
74 return Err(CudaTokenFactGraphLayoutError::ZeroRecordWidth {
75 field: "node_record_bytes",
76 });
77 }
78 if edge_record_bytes == 0 {
79 return Err(CudaTokenFactGraphLayoutError::ZeroRecordWidth {
80 field: "edge_record_bytes",
81 });
82 }
83 let node_count = u64::try_from(graph.node_ids.len()).map_err(|_| {
84 CudaTokenFactGraphLayoutError::ByteCountOverflow {
85 field: "node count",
86 }
87 })?;
88 let edge_count = u64::try_from(graph.column_indices.len()).map_err(|_| {
89 CudaTokenFactGraphLayoutError::ByteCountOverflow {
90 field: "edge count",
91 }
92 })?;
93 let node_bytes = checked_mul(node_count, node_record_bytes, "node bytes")?;
94 let edge_bytes = checked_mul(edge_count, edge_record_bytes, "edge bytes")?;
95 let resident_without_payload = checked_add(node_bytes, edge_bytes, "node plus edge bytes")?;
96 let resident_bytes = checked_add(
97 resident_without_payload,
98 graph.payload_bytes,
99 "resident bytes",
100 )?;
101
102 Ok(CudaTokenFactGraphLayout {
103 graph_shape: CudaMegakernelGraphShape {
104 node_count,
105 edge_count,
106 },
107 node_record_bytes,
108 edge_record_bytes,
109 node_bytes,
110 edge_bytes,
111 payload_bytes: graph.payload_bytes,
112 resident_bytes,
113 })
114}
115
#[cfg(test)]
mod tests {
    use super::*;
    use crate::megakernel_scheduler::{plan_cuda_megakernel_memory_budget, CudaMegakernelTopology};
    use vyre_self_substrate::device_resident_token_fact_graph::{
        plan_device_resident_token_fact_graph, TokenFactEdge, TokenFactEdgeKind, TokenFactNode,
        TokenFactNodeKind,
    };

    /// Shorthand for building a [`TokenFactNode`] fixture.
    fn node(
        id: u32,
        kind: TokenFactNodeKind,
        payload_offset: u64,
        payload_bytes: u64,
    ) -> TokenFactNode {
        TokenFactNode {
            id,
            kind,
            payload_offset,
            payload_bytes,
        }
    }

    /// Shorthand for building a [`TokenFactEdge`] fixture.
    fn edge(from: u32, to: u32, kind: TokenFactEdgeKind) -> TokenFactEdge {
        TokenFactEdge { from, to, kind }
    }

    /// Three nodes chained by two edges, planned with 24 as the final argument.
    /// The layout assertions below show this yields 24 payload bytes.
    fn chain_graph() -> DeviceResidentTokenFactGraph {
        plan_device_resident_token_fact_graph(
            &[
                node(1, TokenFactNodeKind::Token, 0, 8),
                node(2, TokenFactNodeKind::Semantic, 8, 8),
                node(3, TokenFactNodeKind::Fact, 16, 8),
            ],
            &[
                edge(1, 2, TokenFactEdgeKind::SemanticFact),
                edge(2, 3, TokenFactEdgeKind::FactDependency),
            ],
            24,
        )
        .expect("Fix: token/fact graph should pack")
    }

    /// Guards against this file growing private arithmetic helpers instead of
    /// using the shared typed ones. The forbidden needles are assembled with
    /// `concat!` so this test's own source text does not trip the check.
    #[test]
    fn token_fact_adapter_uses_shared_typed_cuda_arithmetic() {
        let source = include_str!("token_fact_graph_cuda_adapter.rs");

        assert!(source.contains("checked_add_u64_count as checked_add"));
        assert!(source.contains("checked_mul_u64_count as checked_mul"));
        assert!(source.contains("impl CudaArithmeticOverflow for CudaTokenFactGraphLayoutError"));
        assert!(!source.contains(concat!("fn checked_", "mul(")));
        assert!(!source.contains(concat!("fn checked_", "add(")));
    }

    /// Checks every field of the adapted layout and that it feeds memory planning.
    #[test]
    fn adapter_accounts_for_cuda_resident_token_fact_layout() {
        let graph = chain_graph();

        let cuda = adapt_token_fact_graph_to_cuda_layout(&graph, 32, 16)
            .expect("Fix: token/fact graph should adapt to CUDA layout");

        assert_eq!(cuda.graph_shape.node_count, 3);
        assert_eq!(cuda.graph_shape.edge_count, 2);
        // Record widths and payload size must pass through unchanged.
        assert_eq!(cuda.node_record_bytes, 32);
        assert_eq!(cuda.edge_record_bytes, 16);
        assert_eq!(cuda.payload_bytes, 24);
        assert_eq!(cuda.node_bytes, 96);
        assert_eq!(cuda.edge_bytes, 32);
        // 96 node bytes + 32 edge bytes + 24 payload bytes.
        assert_eq!(cuda.resident_bytes, 152);

        let memory = plan_cuda_megakernel_memory_budget(
            CudaMegakernelTopology::SparseFrontier,
            cuda.graph_shape,
            cuda.node_record_bytes,
            cuda.edge_record_bytes,
            64,
            cuda.payload_bytes,
            16,
            512,
        )
        .expect("Fix: adapted token/fact graph should feed CUDA memory planning");
        assert_eq!(memory.graph_bytes, 128);
    }

    /// A zero record width is rejected even for an empty graph.
    #[test]
    fn adapter_rejects_missing_abi_widths() {
        let graph = plan_device_resident_token_fact_graph(&[], &[], 0)
            .expect("Fix: empty graph still has a valid resident layout");

        assert_eq!(
            adapt_token_fact_graph_to_cuda_layout(&graph, 0, 8)
                .expect_err("zero node record width should fail"),
            CudaTokenFactGraphLayoutError::ZeroRecordWidth {
                field: "node_record_bytes",
            }
        );
        assert_eq!(
            adapt_token_fact_graph_to_cuda_layout(&graph, 8, 0)
                .expect_err("zero edge record width should fail"),
            CudaTokenFactGraphLayoutError::ZeroRecordWidth {
                field: "edge_record_bytes",
            }
        );
    }

    /// Record widths large enough to overflow `u64` once multiplied by the
    /// counts must surface as `ByteCountOverflow`, never as a wrapped total.
    #[test]
    fn adapter_reports_byte_count_overflow() {
        let graph = chain_graph();

        // 3 nodes * u64::MAX cannot fit in u64.
        let node_overflow = adapt_token_fact_graph_to_cuda_layout(&graph, u64::MAX, 16)
            .expect_err("node byte total should overflow");
        assert!(
            matches!(
                node_overflow,
                CudaTokenFactGraphLayoutError::ByteCountOverflow { .. }
            ),
            "unexpected error: {node_overflow:?}"
        );

        // 2 edges * u64::MAX cannot fit in u64.
        let edge_overflow = adapt_token_fact_graph_to_cuda_layout(&graph, 32, u64::MAX)
            .expect_err("edge byte total should overflow");
        assert!(
            matches!(
                edge_overflow,
                CudaTokenFactGraphLayoutError::ByteCountOverflow { .. }
            ),
            "unexpected error: {edge_overflow:?}"
        );
    }
}