Skip to main content

vyre_driver/
launch_fusion.rs

1//! Backend-neutral adjacent-stage launch fusion planning.
2//!
3//! Backends that dispatch adjacent stages with compatible memory layouts can
4//! fuse them into fewer launches when the fused memory envelope fits an
5//! explicit budget. This module owns the pure planning algorithm so CUDA and
6//! future backends do not carry divergent launch-fusion logic.
7
8use rustc_hash::FxHashSet;
9
10use crate::reservation_policy::ReservationPolicy;
11
12const LAUNCH_FUSION_RESERVATION: ReservationPolicy = ReservationPolicy::new(
13    "adjacent launch fusion",
14    "shard adjacent stages before fusion planning",
15);
16
17/// One adjacent backend stage considered for launch fusion.
18#[derive(Clone, Copy, Debug, Eq, PartialEq)]
19pub struct LaunchFusionStage {
20    /// Stable stage id.
21    pub id: u32,
22    /// Memory-layout compatibility hash.
23    pub layout_hash: u64,
24    /// Input bytes consumed by this stage.
25    pub input_bytes: u64,
26    /// Output bytes produced by this stage.
27    pub output_bytes: u64,
28    /// Scratch bytes required by this stage.
29    pub scratch_bytes: u64,
30    /// Whether this stage boundary requires host-visible materialization.
31    pub requires_host_materialization: bool,
32}
33
34/// One fused adjacent-stage launch group.
35#[derive(Clone, Debug, Eq, PartialEq)]
36pub struct LaunchFusionGroup {
37    /// Stage ids included in the fused group.
38    pub stage_ids: Vec<u32>,
39    /// Shared layout hash for the group.
40    pub layout_hash: u64,
41    /// Peak bytes required by the fused group.
42    pub required_bytes: u64,
43    /// Host-visible intermediate bytes avoided by fusion.
44    pub avoided_intermediate_bytes: u64,
45}
46
47/// Complete adjacent-stage launch fusion plan.
48#[derive(Clone, Debug, Eq, PartialEq)]
49pub struct LaunchFusionPlan {
50    /// Fused or singleton groups in original stage order.
51    pub groups: Vec<LaunchFusionGroup>,
52    /// Number of backend launches after fusion.
53    pub launch_count: u32,
54    /// Number of launches removed by fusion.
55    pub avoided_launches: u32,
56    /// Total host-visible intermediate bytes avoided.
57    pub avoided_intermediate_bytes: u64,
58}
59
60/// Caller-owned scratch for repeated launch-fusion planning.
61#[derive(Debug, Default)]
62pub struct LaunchFusionScratch {
63    ids: FxHashSet<u32>,
64}
65
66impl LaunchFusionScratch {
67    /// Create empty reusable launch-fusion scratch.
68    #[must_use]
69    pub fn new() -> Self {
70        Self {
71            ids: FxHashSet::default(),
72        }
73    }
74
75    /// Allocate reusable launch-fusion scratch for a known stage count.
76    ///
77    /// # Errors
78    ///
79    /// Returns [`LaunchFusionError`] when duplicate-detection storage cannot
80    /// be reserved.
81    pub fn try_with_capacity(stage_count: usize) -> Result<Self, LaunchFusionError> {
82        let mut scratch = Self::new();
83        scratch.try_reserve_ids(stage_count)?;
84        Ok(scratch)
85    }
86
87    fn try_reserve_ids(&mut self, stage_count: usize) -> Result<(), LaunchFusionError> {
88        LAUNCH_FUSION_RESERVATION
89            .reserve_hash_set_to_capacity(&mut self.ids, stage_count, "duplicate stage ids")
90            .map_err(|error| LaunchFusionError::StorageReserveFailed {
91                field: "duplicate stage ids",
92                requested: stage_count,
93                message: error.to_string(),
94            })
95    }
96
97    /// Retained duplicate-detection capacity.
98    #[must_use]
99    pub fn id_capacity(&self) -> usize {
100        self.ids.capacity()
101    }
102}
103
104/// Launch fusion planning errors.
105#[derive(Clone, Debug, Eq, PartialEq)]
106pub enum LaunchFusionError {
107    /// Duplicate stage id.
108    DuplicateStage {
109        /// Duplicate id.
110        id: u32,
111    },
112    /// Explicit fusion budget cannot be zero.
113    ZeroBudget,
114    /// Byte arithmetic overflowed.
115    ByteCountOverflow {
116        /// Field being computed.
117        field: &'static str,
118    },
119    /// One stage cannot fit the explicit fusion budget even without fusion.
120    StageOverBudget {
121        /// Stage id.
122        id: u32,
123        /// Required bytes for the singleton stage.
124        required_bytes: u64,
125        /// Caller-provided budget.
126        budget_bytes: u64,
127    },
128    /// Planner storage could not be reserved.
129    StorageReserveFailed {
130        /// Field being reserved.
131        field: &'static str,
132        /// Number of entries requested.
133        requested: usize,
134        /// Allocator error text.
135        message: String,
136    },
137}
138
139impl std::fmt::Display for LaunchFusionError {
140    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
141        match self {
142            Self::DuplicateStage { id } => write!(
143                f,
144                "Launch fusion received duplicate stage id {id}. Fix: emit unique stage ids before fusion planning."
145            ),
146            Self::ZeroBudget => write!(
147                f,
148                "Launch fusion received a zero byte budget. Fix: pass an explicit device-memory budget before planning fusion."
149            ),
150            Self::ByteCountOverflow { field } => write!(
151                f,
152                "Launch fusion overflowed while computing {field}. Fix: shard adjacent stages before launch fusion planning."
153            ),
154            Self::StageOverBudget {
155                id,
156                required_bytes,
157                budget_bytes,
158            } => write!(
159                f,
160                "Launch fusion stage {id} requires {required_bytes} bytes but budget allows {budget_bytes}. Fix: shard the stage or raise the explicit fusion budget."
161            ),
162            Self::StorageReserveFailed {
163                field,
164                requested,
165                message,
166            } => write!(
167                f,
168                "Launch fusion could not reserve {requested} {field} entries: {message}. Fix: shard adjacent stages before fusion planning."
169            ),
170        }
171    }
172}
173
174impl std::error::Error for LaunchFusionError {}
175
176/// Plan adjacent launch fusion under layout and memory constraints.
177///
178/// # Errors
179///
180/// Returns [`LaunchFusionError`] when inputs are invalid, byte arithmetic
181/// overflows, staging allocation fails, or any singleton stage exceeds the
182/// explicit budget.
183pub fn plan_launch_fusion(
184    stages: &[LaunchFusionStage],
185    max_group_bytes: u64,
186) -> Result<LaunchFusionPlan, LaunchFusionError> {
187    let mut scratch = LaunchFusionScratch::try_with_capacity(stages.len())?;
188    plan_launch_fusion_with_scratch(stages, max_group_bytes, &mut scratch)
189}
190
191/// Plan adjacent launch fusion using caller-owned temporary storage.
192///
193/// # Errors
194///
195/// Returns [`LaunchFusionError`] when inputs are invalid, byte arithmetic
196/// overflows, staging allocation fails, or any singleton stage exceeds the
197/// explicit budget.
198pub fn plan_launch_fusion_with_scratch(
199    stages: &[LaunchFusionStage],
200    max_group_bytes: u64,
201    scratch: &mut LaunchFusionScratch,
202) -> Result<LaunchFusionPlan, LaunchFusionError> {
203    if max_group_bytes == 0 {
204        return Err(LaunchFusionError::ZeroBudget);
205    }
206    if stages.is_empty() {
207        return Ok(LaunchFusionPlan {
208            groups: Vec::new(),
209            launch_count: 0,
210            avoided_launches: 0,
211            avoided_intermediate_bytes: 0,
212        });
213    }
214    if stages.len() == 1 {
215        let group = singleton_group_with_capacity(stages[0], 1)?;
216        if group.required_bytes > max_group_bytes {
217            return Err(LaunchFusionError::StageOverBudget {
218                id: stages[0].id,
219                required_bytes: group.required_bytes,
220                budget_bytes: max_group_bytes,
221            });
222        }
223        let mut groups = reserved_vec(1, "fusion groups")?;
224        groups.push(group);
225        return Ok(LaunchFusionPlan {
226            groups,
227            launch_count: 1,
228            avoided_launches: 0,
229            avoided_intermediate_bytes: 0,
230        });
231    }
232
233    scratch.ids.clear();
234    if stages.len() <= 8 {
235        for i in 0..stages.len() {
236            let current = stages[i].id;
237            if stages[..i].iter().any(|prev| prev.id == current) {
238                return Err(LaunchFusionError::DuplicateStage { id: current });
239            }
240        }
241    } else {
242        scratch.try_reserve_ids(stages.len())?;
243        for stage in stages {
244            if !scratch.ids.insert(stage.id) {
245                return Err(LaunchFusionError::DuplicateStage { id: stage.id });
246            }
247        }
248    }
249
250    let mut groups = reserved_vec(stages.len(), "fusion groups")?;
251    let mut index = 0;
252    while index < stages.len() {
253        let remaining_stage_count = stages.len() - index;
254        let mut group = singleton_group_with_capacity(stages[index], remaining_stage_count)?;
255        if group.required_bytes > max_group_bytes {
256            return Err(LaunchFusionError::StageOverBudget {
257                id: stages[index].id,
258                required_bytes: group.required_bytes,
259                budget_bytes: max_group_bytes,
260            });
261        }
262        let mut cursor = index + 1;
263        while cursor < stages.len() && can_append_to_group(&group, stages[cursor], max_group_bytes)?
264        {
265            let previous_output = stages[cursor - 1].output_bytes;
266            group.required_bytes = fused_required_bytes(&group, stages[cursor])?;
267            group.avoided_intermediate_bytes = checked_add_u64(
268                group.avoided_intermediate_bytes,
269                previous_output,
270                "avoided intermediate bytes",
271            )?;
272            group.stage_ids.push(stages[cursor].id);
273            cursor += 1;
274        }
275        groups.push(group);
276        index = cursor;
277    }
278
279    let launch_count =
280        u32::try_from(groups.len()).map_err(|_| LaunchFusionError::ByteCountOverflow {
281            field: "launch count",
282        })?;
283    let avoided_launches = u32::try_from(stages.len() - groups.len()).map_err(|_| {
284        LaunchFusionError::ByteCountOverflow {
285            field: "avoided launches",
286        }
287    })?;
288    let mut avoided_intermediate_bytes = 0_u64;
289    for group in &groups {
290        avoided_intermediate_bytes = checked_add_u64(
291            avoided_intermediate_bytes,
292            group.avoided_intermediate_bytes,
293            "total avoided intermediate bytes",
294        )?;
295    }
296
297    Ok(LaunchFusionPlan {
298        groups,
299        launch_count,
300        avoided_launches,
301        avoided_intermediate_bytes,
302    })
303}
304
305fn reserved_vec<T>(capacity: usize, field: &'static str) -> Result<Vec<T>, LaunchFusionError> {
306    LAUNCH_FUSION_RESERVATION
307        .reserved_vec(capacity, field)
308        .map_err(|error| LaunchFusionError::StorageReserveFailed {
309            field,
310            requested: capacity,
311            message: error.to_string(),
312        })
313}
314
315fn singleton_group_with_capacity(
316    stage: LaunchFusionStage,
317    stage_id_capacity: usize,
318) -> Result<LaunchFusionGroup, LaunchFusionError> {
319    let mut stage_ids = reserved_vec(stage_id_capacity.max(1), "fusion group stage ids")?;
320    stage_ids.push(stage.id);
321    Ok(LaunchFusionGroup {
322        stage_ids,
323        layout_hash: stage.layout_hash,
324        required_bytes: stage_required_bytes(stage)?,
325        avoided_intermediate_bytes: 0,
326    })
327}
328
329fn can_append_to_group(
330    group: &LaunchFusionGroup,
331    stage: LaunchFusionStage,
332    max_group_bytes: u64,
333) -> Result<bool, LaunchFusionError> {
334    if stage.requires_host_materialization || stage.layout_hash != group.layout_hash {
335        return Ok(false);
336    }
337    Ok(fused_required_bytes(group, stage)? <= max_group_bytes)
338}
339
340fn fused_required_bytes(
341    group: &LaunchFusionGroup,
342    stage: LaunchFusionStage,
343) -> Result<u64, LaunchFusionError> {
344    checked_add_u64(
345        group.required_bytes,
346        stage.scratch_bytes,
347        "fused scratch bytes",
348    )
349    .and_then(|bytes| checked_add_u64(bytes, stage.output_bytes, "fused output bytes"))
350}
351
352fn stage_required_bytes(stage: LaunchFusionStage) -> Result<u64, LaunchFusionError> {
353    let input_plus_output =
354        checked_add_u64(stage.input_bytes, stage.output_bytes, "stage io bytes")?;
355    checked_add_u64(
356        input_plus_output,
357        stage.scratch_bytes,
358        "stage required bytes",
359    )
360}
361
362fn checked_add_u64(left: u64, right: u64, field: &'static str) -> Result<u64, LaunchFusionError> {
363    left.checked_add(right)
364        .ok_or(LaunchFusionError::ByteCountOverflow { field })
365}
366
367#[cfg(test)]
368mod tests {
369    use super::*;
370
371    #[test]
372    fn launch_fusion_groups_adjacent_compatible_stages() {
373        let plan = plan_launch_fusion(
374            &[
375                stage(1, 7, 64, 32, 8, false),
376                stage(2, 7, 32, 48, 8, false),
377                stage(3, 7, 48, 16, 8, false),
378            ],
379            256,
380        )
381        .expect("Fix: compatible stages should fuse");
382
383        assert_eq!(plan.launch_count, 1);
384        assert_eq!(plan.avoided_launches, 2);
385        assert_eq!(plan.groups[0].stage_ids, vec![1, 2, 3]);
386        assert_eq!(plan.avoided_intermediate_bytes, 80);
387    }
388
389    #[test]
390    fn launch_fusion_splits_on_layout_host_boundary_and_budget() {
391        let plan = plan_launch_fusion(
392            &[
393                stage(1, 7, 64, 32, 8, false),
394                stage(2, 8, 32, 48, 8, false),
395                stage(3, 8, 48, 16, 8, true),
396                stage(4, 9, 16, 16, 8, false),
397            ],
398            128,
399        )
400        .expect("Fix: incompatible stages should split deterministically");
401
402        assert_eq!(plan.launch_count, 4);
403        assert_eq!(plan.avoided_launches, 0);
404        assert_eq!(plan.groups[0].stage_ids, vec![1]);
405        assert_eq!(plan.groups[1].stage_ids, vec![2]);
406        assert_eq!(plan.groups[2].stage_ids, vec![3]);
407        assert_eq!(plan.groups[3].stage_ids, vec![4]);
408    }
409
410    #[test]
411    fn launch_fusion_rejects_invalid_inputs() {
412        assert_eq!(
413            plan_launch_fusion(&[stage(1, 7, 1, 1, 1, false)], 0)
414                .expect_err("zero budget should fail"),
415            LaunchFusionError::ZeroBudget
416        );
417        assert_eq!(
418            plan_launch_fusion(
419                &[stage(1, 7, 1, 1, 1, false), stage(1, 7, 1, 1, 1, false),],
420                128,
421            )
422            .expect_err("duplicate stages should fail"),
423            LaunchFusionError::DuplicateStage { id: 1 }
424        );
425        assert_eq!(
426            plan_launch_fusion(&[stage(9, 7, 64, 32, 64, false)], 128)
427                .expect_err("single over-budget stage should fail"),
428            LaunchFusionError::StageOverBudget {
429                id: 9,
430                required_bytes: 160,
431                budget_bytes: 128,
432            }
433        );
434    }
435
436    #[test]
437    fn generated_launch_fusion_preserves_budget_and_order_contract() {
438        for seed in 0..4096_u64 {
439            let stages = generated_stages(seed);
440            let budget = 96 + (seed % 512);
441            let plan = plan_launch_fusion(&stages, budget)
442                .or_else(|error| match error {
443                    LaunchFusionError::StageOverBudget { .. } => Ok(LaunchFusionPlan {
444                        groups: Vec::new(),
445                        launch_count: 0,
446                        avoided_launches: 0,
447                        avoided_intermediate_bytes: 0,
448                    }),
449                    other => Err(other),
450                })
451                .expect(
452                    "Fix: generated launch fusion should only reject singleton over-budget stages",
453                );
454            if plan.groups.is_empty() {
455                continue;
456            }
457
458            let flattened = plan
459                .groups
460                .iter()
461                .flat_map(|group| group.stage_ids.iter().copied())
462                .collect::<Vec<_>>();
463            assert_eq!(
464                flattened,
465                stages.iter().map(|stage| stage.id).collect::<Vec<_>>(),
466                "Fix: launch fusion must preserve original stage order for seed {seed}."
467            );
468            assert_eq!(
469                usize::try_from(plan.launch_count).expect("Fix: plan launch_count must fit usize on this platform; reject oversized plans upstream - launch_count fits usize"),
470                plan.groups.len(),
471                "Fix: launch_count must match group count for seed {seed}."
472            );
473            assert_eq!(
474                usize::try_from(plan.avoided_launches).expect("Fix: avoided_launches must fit usize; clamp or reject plan before fusion stats - avoided_launches fits usize"),
475                stages.len() - plan.groups.len(),
476                "Fix: avoided_launches must match fused group reduction for seed {seed}."
477            );
478            for group in &plan.groups {
479                assert!(
480                    group.required_bytes <= budget,
481                    "Fix: fused group exceeded explicit budget for seed {seed}."
482                );
483            }
484        }
485    }
486
487    #[test]
488    fn launch_fusion_reuses_caller_owned_duplicate_detection_scratch() {
489        let mut scratch =
490            LaunchFusionScratch::try_with_capacity(64).expect("Fix: fusion scratch should reserve");
491        let wide = (0..64)
492            .map(|id| stage(id, 7, 16, 16, 4, false))
493            .collect::<Vec<_>>();
494        let first = plan_launch_fusion_with_scratch(&wide, 8_192, &mut scratch)
495            .expect("Fix: wide compatible stages should fuse");
496        let id_capacity = scratch.id_capacity();
497
498        assert_eq!(first.launch_count, 1);
499        assert_eq!(first.avoided_launches, 63);
500
501        let second = plan_launch_fusion_with_scratch(
502            &[
503                stage(10, 7, 64, 32, 8, false),
504                stage(11, 8, 32, 48, 8, false),
505            ],
506            512,
507            &mut scratch,
508        )
509        .expect("Fix: smaller incompatible stages should reuse duplicate-detection scratch");
510
511        assert_eq!(second.launch_count, 2);
512        assert!(scratch.id_capacity() >= id_capacity);
513    }
514
515    #[test]
516    fn launch_fusion_staging_reserves_fallibly() {
517        let src = include_str!("launch_fusion.rs");
518
519        assert!(
520            src.contains("LaunchFusionScratch::try_with_capacity(stages.len())?")
521                && src.contains("scratch.try_reserve_ids(stages.len())?")
522                && src.contains("ReservationPolicy")
523                && src.contains("StorageReserveFailed"),
524            "Fix: launch fusion staging must use shared fallible reservations under scale pressure."
525        );
526        assert!(
527            !src.contains(concat!("FxHashSet::with_capacity", "_and_hasher"))
528                && !src.contains(concat!("Vec::with_capacity", "(stages.len())"))
529                && !src.contains(concat!("groups: vec![", "group]"))
530                && !src.contains(concat!("stage_ids: vec![", "stage.id]"))
531                && !src.contains(concat!("scratch.ids", ".reserve(stages.len())")),
532            "Fix: launch fusion release planning must not use infallible staging allocation."
533        );
534    }
535
536    fn generated_stages(seed: u64) -> Vec<LaunchFusionStage> {
537        let count = 1 + (seed as usize % 24);
538        let mut stages = Vec::with_capacity(count);
539        let mut state = seed ^ 0xF051_1A4A_7E57_0001;
540        for index in 0..count {
541            stages.push(stage(
542                index as u32,
543                next_u64(&mut state) % 5,
544                1 + (next_u64(&mut state) % 48),
545                1 + (next_u64(&mut state) % 48),
546                next_u64(&mut state) % 24,
547                next_u64(&mut state) % 11 == 0,
548            ));
549        }
550        stages
551    }
552
553    fn stage(
554        id: u32,
555        layout_hash: u64,
556        input_bytes: u64,
557        output_bytes: u64,
558        scratch_bytes: u64,
559        requires_host_materialization: bool,
560    ) -> LaunchFusionStage {
561        LaunchFusionStage {
562            id,
563            layout_hash,
564            input_bytes,
565            output_bytes,
566            scratch_bytes,
567            requires_host_materialization,
568        }
569    }
570
571    fn next_u64(state: &mut u64) -> u64 {
572        let mut x = *state;
573        x ^= x << 13;
574        x ^= x >> 7;
575        x ^= x << 17;
576        *state = x;
577        x
578    }
579}