1use rustc_hash::FxHashSet;
9
10use crate::reservation_policy::ReservationPolicy;
11
12const LAUNCH_FUSION_RESERVATION: ReservationPolicy = ReservationPolicy::new(
13 "adjacent launch fusion",
14 "shard adjacent stages before fusion planning",
15);
16
/// One stage handed to the launch fusion planner.
///
/// Stages are planned in slice order and only neighbouring stages can share a launch.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LaunchFusionStage {
    /// Stage identifier; must be unique within a single planning call.
    pub id: u32,
    /// Layout fingerprint; a stage only joins a group whose layout hash is equal.
    pub layout_hash: u64,
    /// Input byte count; charged only when this stage starts a new group.
    pub input_bytes: u64,
    /// Output byte count; charged for every stage in a group. When a later stage is fused
    /// after this one, this amount is also tallied as avoided intermediate bytes.
    pub output_bytes: u64,
    /// Scratch byte count; charged for every stage in a group.
    pub scratch_bytes: u64,
    /// When set, this stage is never appended to a preceding group.
    pub requires_host_materialization: bool,
}
33
/// A run of adjacent stages that the planner merged into a single launch.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LaunchFusionGroup {
    /// Ids of the member stages, in input order.
    pub stage_ids: Vec<u32>,
    /// Layout hash shared by every member stage.
    pub layout_hash: u64,
    /// Bytes charged to the launch: the first stage's input + output + scratch, plus output +
    /// scratch for each appended stage.
    pub required_bytes: u64,
    /// Sum of the outputs of every member except the last, which no longer round-trip
    /// between launches.
    pub avoided_intermediate_bytes: u64,
}
46
47#[derive(Clone, Debug, Eq, PartialEq)]
49pub struct LaunchFusionPlan {
50 pub groups: Vec<LaunchFusionGroup>,
52 pub launch_count: u32,
54 pub avoided_launches: u32,
56 pub avoided_intermediate_bytes: u64,
58}
59
60#[derive(Debug, Default)]
62pub struct LaunchFusionScratch {
63 ids: FxHashSet<u32>,
64}
65
66impl LaunchFusionScratch {
67 #[must_use]
69 pub fn new() -> Self {
70 Self {
71 ids: FxHashSet::default(),
72 }
73 }
74
75 pub fn try_with_capacity(stage_count: usize) -> Result<Self, LaunchFusionError> {
82 let mut scratch = Self::new();
83 scratch.try_reserve_ids(stage_count)?;
84 Ok(scratch)
85 }
86
87 fn try_reserve_ids(&mut self, stage_count: usize) -> Result<(), LaunchFusionError> {
88 LAUNCH_FUSION_RESERVATION
89 .reserve_hash_set_to_capacity(&mut self.ids, stage_count, "duplicate stage ids")
90 .map_err(|error| LaunchFusionError::StorageReserveFailed {
91 field: "duplicate stage ids",
92 requested: stage_count,
93 message: error.to_string(),
94 })
95 }
96
97 #[must_use]
99 pub fn id_capacity(&self) -> usize {
100 self.ids.capacity()
101 }
102}
103
/// Reasons launch fusion planning (or its scratch reservation) can fail.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LaunchFusionError {
    /// Two stages in one planning call share an id.
    DuplicateStage {
        /// The id that appeared more than once.
        id: u32,
    },
    /// The caller passed a byte budget of zero.
    ZeroBudget,
    /// A checked byte sum or launch counter did not fit its integer type.
    ByteCountOverflow {
        /// Label of the quantity being computed when the overflow occurred.
        field: &'static str,
    },
    /// A stage that starts a group needs more bytes than the whole budget.
    StageOverBudget {
        /// Offending stage id.
        id: u32,
        /// Bytes the stage needs on its own (input + output + scratch).
        required_bytes: u64,
        /// The budget that was exceeded.
        budget_bytes: u64,
    },
    /// The shared reservation policy refused to reserve staging storage.
    StorageReserveFailed {
        /// Label of the storage being reserved.
        field: &'static str,
        /// Number of entries requested.
        requested: usize,
        /// Text of the underlying reservation error.
        message: String,
    },
}
138
139impl std::fmt::Display for LaunchFusionError {
140 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
141 match self {
142 Self::DuplicateStage { id } => write!(
143 f,
144 "Launch fusion received duplicate stage id {id}. Fix: emit unique stage ids before fusion planning."
145 ),
146 Self::ZeroBudget => write!(
147 f,
148 "Launch fusion received a zero byte budget. Fix: pass an explicit device-memory budget before planning fusion."
149 ),
150 Self::ByteCountOverflow { field } => write!(
151 f,
152 "Launch fusion overflowed while computing {field}. Fix: shard adjacent stages before launch fusion planning."
153 ),
154 Self::StageOverBudget {
155 id,
156 required_bytes,
157 budget_bytes,
158 } => write!(
159 f,
160 "Launch fusion stage {id} requires {required_bytes} bytes but budget allows {budget_bytes}. Fix: shard the stage or raise the explicit fusion budget."
161 ),
162 Self::StorageReserveFailed {
163 field,
164 requested,
165 message,
166 } => write!(
167 f,
168 "Launch fusion could not reserve {requested} {field} entries: {message}. Fix: shard adjacent stages before fusion planning."
169 ),
170 }
171 }
172}
173
174impl std::error::Error for LaunchFusionError {}
175
176pub fn plan_launch_fusion(
184 stages: &[LaunchFusionStage],
185 max_group_bytes: u64,
186) -> Result<LaunchFusionPlan, LaunchFusionError> {
187 let mut scratch = LaunchFusionScratch::try_with_capacity(stages.len())?;
188 plan_launch_fusion_with_scratch(stages, max_group_bytes, &mut scratch)
189}
190
191pub fn plan_launch_fusion_with_scratch(
199 stages: &[LaunchFusionStage],
200 max_group_bytes: u64,
201 scratch: &mut LaunchFusionScratch,
202) -> Result<LaunchFusionPlan, LaunchFusionError> {
203 if max_group_bytes == 0 {
204 return Err(LaunchFusionError::ZeroBudget);
205 }
206 if stages.is_empty() {
207 return Ok(LaunchFusionPlan {
208 groups: Vec::new(),
209 launch_count: 0,
210 avoided_launches: 0,
211 avoided_intermediate_bytes: 0,
212 });
213 }
214 if stages.len() == 1 {
215 let group = singleton_group_with_capacity(stages[0], 1)?;
216 if group.required_bytes > max_group_bytes {
217 return Err(LaunchFusionError::StageOverBudget {
218 id: stages[0].id,
219 required_bytes: group.required_bytes,
220 budget_bytes: max_group_bytes,
221 });
222 }
223 let mut groups = reserved_vec(1, "fusion groups")?;
224 groups.push(group);
225 return Ok(LaunchFusionPlan {
226 groups,
227 launch_count: 1,
228 avoided_launches: 0,
229 avoided_intermediate_bytes: 0,
230 });
231 }
232
233 scratch.ids.clear();
234 if stages.len() <= 8 {
235 for i in 0..stages.len() {
236 let current = stages[i].id;
237 if stages[..i].iter().any(|prev| prev.id == current) {
238 return Err(LaunchFusionError::DuplicateStage { id: current });
239 }
240 }
241 } else {
242 scratch.try_reserve_ids(stages.len())?;
243 for stage in stages {
244 if !scratch.ids.insert(stage.id) {
245 return Err(LaunchFusionError::DuplicateStage { id: stage.id });
246 }
247 }
248 }
249
250 let mut groups = reserved_vec(stages.len(), "fusion groups")?;
251 let mut index = 0;
252 while index < stages.len() {
253 let remaining_stage_count = stages.len() - index;
254 let mut group = singleton_group_with_capacity(stages[index], remaining_stage_count)?;
255 if group.required_bytes > max_group_bytes {
256 return Err(LaunchFusionError::StageOverBudget {
257 id: stages[index].id,
258 required_bytes: group.required_bytes,
259 budget_bytes: max_group_bytes,
260 });
261 }
262 let mut cursor = index + 1;
263 while cursor < stages.len() && can_append_to_group(&group, stages[cursor], max_group_bytes)?
264 {
265 let previous_output = stages[cursor - 1].output_bytes;
266 group.required_bytes = fused_required_bytes(&group, stages[cursor])?;
267 group.avoided_intermediate_bytes = checked_add_u64(
268 group.avoided_intermediate_bytes,
269 previous_output,
270 "avoided intermediate bytes",
271 )?;
272 group.stage_ids.push(stages[cursor].id);
273 cursor += 1;
274 }
275 groups.push(group);
276 index = cursor;
277 }
278
279 let launch_count =
280 u32::try_from(groups.len()).map_err(|_| LaunchFusionError::ByteCountOverflow {
281 field: "launch count",
282 })?;
283 let avoided_launches = u32::try_from(stages.len() - groups.len()).map_err(|_| {
284 LaunchFusionError::ByteCountOverflow {
285 field: "avoided launches",
286 }
287 })?;
288 let mut avoided_intermediate_bytes = 0_u64;
289 for group in &groups {
290 avoided_intermediate_bytes = checked_add_u64(
291 avoided_intermediate_bytes,
292 group.avoided_intermediate_bytes,
293 "total avoided intermediate bytes",
294 )?;
295 }
296
297 Ok(LaunchFusionPlan {
298 groups,
299 launch_count,
300 avoided_launches,
301 avoided_intermediate_bytes,
302 })
303}
304
305fn reserved_vec<T>(capacity: usize, field: &'static str) -> Result<Vec<T>, LaunchFusionError> {
306 LAUNCH_FUSION_RESERVATION
307 .reserved_vec(capacity, field)
308 .map_err(|error| LaunchFusionError::StorageReserveFailed {
309 field,
310 requested: capacity,
311 message: error.to_string(),
312 })
313}
314
315fn singleton_group_with_capacity(
316 stage: LaunchFusionStage,
317 stage_id_capacity: usize,
318) -> Result<LaunchFusionGroup, LaunchFusionError> {
319 let mut stage_ids = reserved_vec(stage_id_capacity.max(1), "fusion group stage ids")?;
320 stage_ids.push(stage.id);
321 Ok(LaunchFusionGroup {
322 stage_ids,
323 layout_hash: stage.layout_hash,
324 required_bytes: stage_required_bytes(stage)?,
325 avoided_intermediate_bytes: 0,
326 })
327}
328
329fn can_append_to_group(
330 group: &LaunchFusionGroup,
331 stage: LaunchFusionStage,
332 max_group_bytes: u64,
333) -> Result<bool, LaunchFusionError> {
334 if stage.requires_host_materialization || stage.layout_hash != group.layout_hash {
335 return Ok(false);
336 }
337 Ok(fused_required_bytes(group, stage)? <= max_group_bytes)
338}
339
340fn fused_required_bytes(
341 group: &LaunchFusionGroup,
342 stage: LaunchFusionStage,
343) -> Result<u64, LaunchFusionError> {
344 checked_add_u64(
345 group.required_bytes,
346 stage.scratch_bytes,
347 "fused scratch bytes",
348 )
349 .and_then(|bytes| checked_add_u64(bytes, stage.output_bytes, "fused output bytes"))
350}
351
352fn stage_required_bytes(stage: LaunchFusionStage) -> Result<u64, LaunchFusionError> {
353 let input_plus_output =
354 checked_add_u64(stage.input_bytes, stage.output_bytes, "stage io bytes")?;
355 checked_add_u64(
356 input_plus_output,
357 stage.scratch_bytes,
358 "stage required bytes",
359 )
360}
361
362fn checked_add_u64(left: u64, right: u64, field: &'static str) -> Result<u64, LaunchFusionError> {
363 left.checked_add(right)
364 .ok_or(LaunchFusionError::ByteCountOverflow { field })
365}
366
#[cfg(test)]
mod tests {
    use super::*;

    // Three same-layout, host-free stages within budget collapse into one launch; the avoided
    // bytes are the outputs of stages 1 and 2 (32 + 48).
    #[test]
    fn launch_fusion_groups_adjacent_compatible_stages() {
        let plan = plan_launch_fusion(
            &[
                stage(1, 7, 64, 32, 8, false),
                stage(2, 7, 32, 48, 8, false),
                stage(3, 7, 48, 16, 8, false),
            ],
            256,
        )
        .expect("Fix: compatible stages should fuse");

        assert_eq!(plan.launch_count, 1);
        assert_eq!(plan.avoided_launches, 2);
        assert_eq!(plan.groups[0].stage_ids, vec![1, 2, 3]);
        assert_eq!(plan.avoided_intermediate_bytes, 80);
    }

    // Each split here comes from a layout change (stages 2 and 4) or the host-materialization
    // flag (stage 3), so every stage gets its own launch.
    #[test]
    fn launch_fusion_splits_on_layout_host_boundary_and_budget() {
        let plan = plan_launch_fusion(
            &[
                stage(1, 7, 64, 32, 8, false),
                stage(2, 8, 32, 48, 8, false),
                stage(3, 8, 48, 16, 8, true),
                stage(4, 9, 16, 16, 8, false),
            ],
            128,
        )
        .expect("Fix: incompatible stages should split deterministically");

        assert_eq!(plan.launch_count, 4);
        assert_eq!(plan.avoided_launches, 0);
        assert_eq!(plan.groups[0].stage_ids, vec![1]);
        assert_eq!(plan.groups[1].stage_ids, vec![2]);
        assert_eq!(plan.groups[2].stage_ids, vec![3]);
        assert_eq!(plan.groups[3].stage_ids, vec![4]);
    }

    // Early rejections: zero budget, a repeated id, and a lone stage whose 64 + 32 + 64 bytes
    // exceed the 128-byte budget.
    #[test]
    fn launch_fusion_rejects_invalid_inputs() {
        assert_eq!(
            plan_launch_fusion(&[stage(1, 7, 1, 1, 1, false)], 0)
                .expect_err("zero budget should fail"),
            LaunchFusionError::ZeroBudget
        );
        assert_eq!(
            plan_launch_fusion(
                &[stage(1, 7, 1, 1, 1, false), stage(1, 7, 1, 1, 1, false),],
                128,
            )
            .expect_err("duplicate stages should fail"),
            LaunchFusionError::DuplicateStage { id: 1 }
        );
        assert_eq!(
            plan_launch_fusion(&[stage(9, 7, 64, 32, 64, false)], 128)
                .expect_err("single over-budget stage should fail"),
            LaunchFusionError::StageOverBudget {
                id: 9,
                required_bytes: 160,
                budget_bytes: 128,
            }
        );
    }

    // Sweep over xorshift-generated stage lists. Over-budget rejections are skipped; every
    // accepted plan must preserve stage order, keep its counters consistent with the groups,
    // and keep each group within the budget.
    #[test]
    fn generated_launch_fusion_preserves_budget_and_order_contract() {
        for seed in 0..4096_u64 {
            let stages = generated_stages(seed);
            let budget = 96 + (seed % 512);
            let plan = plan_launch_fusion(&stages, budget)
                .or_else(|error| match error {
                    LaunchFusionError::StageOverBudget { .. } => Ok(LaunchFusionPlan {
                        groups: Vec::new(),
                        launch_count: 0,
                        avoided_launches: 0,
                        avoided_intermediate_bytes: 0,
                    }),
                    other => Err(other),
                })
                .expect(
                    "Fix: generated launch fusion should only reject singleton over-budget stages",
                );
            if plan.groups.is_empty() {
                continue;
            }

            let flattened = plan
                .groups
                .iter()
                .flat_map(|group| group.stage_ids.iter().copied())
                .collect::<Vec<_>>();
            assert_eq!(
                flattened,
                stages.iter().map(|stage| stage.id).collect::<Vec<_>>(),
                "Fix: launch fusion must preserve original stage order for seed {seed}."
            );
            assert_eq!(
                usize::try_from(plan.launch_count).expect("Fix: plan launch_count must fit usize on this platform; reject oversized plans upstream - launch_count fits usize"),
                plan.groups.len(),
                "Fix: launch_count must match group count for seed {seed}."
            );
            assert_eq!(
                usize::try_from(plan.avoided_launches).expect("Fix: avoided_launches must fit usize; clamp or reject plan before fusion stats - avoided_launches fits usize"),
                stages.len() - plan.groups.len(),
                "Fix: avoided_launches must match fused group reduction for seed {seed}."
            );
            for group in &plan.groups {
                assert!(
                    group.required_bytes <= budget,
                    "Fix: fused group exceeded explicit budget for seed {seed}."
                );
            }
        }
    }

    // 64 stages exercise the hash-set duplicate path. A second, smaller call reusing the same
    // scratch must not shrink its id capacity.
    #[test]
    fn launch_fusion_reuses_caller_owned_duplicate_detection_scratch() {
        let mut scratch =
            LaunchFusionScratch::try_with_capacity(64).expect("Fix: fusion scratch should reserve");
        let wide = (0..64)
            .map(|id| stage(id, 7, 16, 16, 4, false))
            .collect::<Vec<_>>();
        let first = plan_launch_fusion_with_scratch(&wide, 8_192, &mut scratch)
            .expect("Fix: wide compatible stages should fuse");
        let id_capacity = scratch.id_capacity();

        assert_eq!(first.launch_count, 1);
        assert_eq!(first.avoided_launches, 63);

        let second = plan_launch_fusion_with_scratch(
            &[
                stage(10, 7, 64, 32, 8, false),
                stage(11, 8, 32, 48, 8, false),
            ],
            512,
            &mut scratch,
        )
        .expect("Fix: smaller incompatible stages should reuse duplicate-detection scratch");

        assert_eq!(second.launch_count, 2);
        assert!(scratch.id_capacity() >= id_capacity);
    }

    // Source-text guard: reads this file and checks that staging allocations go through the
    // fallible reservation helpers. The forbidden patterns are assembled with `concat!` so this
    // test's own text does not match them.
    #[test]
    fn launch_fusion_staging_reserves_fallibly() {
        let src = include_str!("launch_fusion.rs");

        assert!(
            src.contains("LaunchFusionScratch::try_with_capacity(stages.len())?")
                && src.contains("scratch.try_reserve_ids(stages.len())?")
                && src.contains("ReservationPolicy")
                && src.contains("StorageReserveFailed"),
            "Fix: launch fusion staging must use shared fallible reservations under scale pressure."
        );
        assert!(
            !src.contains(concat!("FxHashSet::with_capacity", "_and_hasher"))
                && !src.contains(concat!("Vec::with_capacity", "(stages.len())"))
                && !src.contains(concat!("groups: vec![", "group]"))
                && !src.contains(concat!("stage_ids: vec![", "stage.id]"))
                && !src.contains(concat!("scratch.ids", ".reserve(stages.len())")),
            "Fix: launch fusion release planning must not use infallible staging allocation."
        );
    }

    // Deterministic fixture generator: 1..=24 stages with ids equal to their slice index,
    // layout hashes in 0..5, and roughly one in eleven stages requiring host materialization.
    fn generated_stages(seed: u64) -> Vec<LaunchFusionStage> {
        let count = 1 + (seed as usize % 24);
        let mut stages = Vec::with_capacity(count);
        let mut state = seed ^ 0xF051_1A4A_7E57_0001;
        for index in 0..count {
            stages.push(stage(
                index as u32,
                next_u64(&mut state) % 5,
                1 + (next_u64(&mut state) % 48),
                1 + (next_u64(&mut state) % 48),
                next_u64(&mut state) % 24,
                next_u64(&mut state) % 11 == 0,
            ));
        }
        stages
    }

    // Positional constructor that keeps the test fixtures on one line each.
    fn stage(
        id: u32,
        layout_hash: u64,
        input_bytes: u64,
        output_bytes: u64,
        scratch_bytes: u64,
        requires_host_materialization: bool,
    ) -> LaunchFusionStage {
        LaunchFusionStage {
            id,
            layout_hash,
            input_bytes,
            output_bytes,
            scratch_bytes,
            requires_host_materialization,
        }
    }

    // Xorshift64 step (shifts 13, 7, 17): advances `state` in place and returns the new value.
    fn next_u64(state: &mut u64) -> u64 {
        let mut x = *state;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        *state = x;
        x
    }
}
579}