Skip to main content

vyre_driver_cuda/
megakernel_plan_cache.rs

1//! Bounded CUDA megakernel plan cache.
2//!
3//! The cache stores topology decisions keyed by stable graph layout,
4//! analysis family, CUDA device feature signature, and coarse runtime-pressure
5//! buckets. The first three fields are the architectural identity of a plan;
6//! pressure buckets prevent a sparse first query from poisoning dense later
7//! queries over the same resident graph.
8
9use std::cmp::Reverse;
10use std::collections::BinaryHeap;
11
12use rustc_hash::FxHashMap;
13
14use crate::backend::ordering::sort_unstable_by_key_if_needed;
15use crate::backend::staging_reserve::reserve_vec;
16use crate::device::CudaDeviceCaps;
17use crate::megakernel_scheduler::{
18    plan_cuda_megakernel_memory_budget, select_cuda_megakernel_topology,
19    select_cuda_megakernel_topology_stable, CudaMegakernelExecutionPlan, CudaMegakernelGraphShape,
20    CudaMegakernelMemoryBudget, CudaMegakernelMemoryError, CudaMegakernelScheduleSample,
21    CudaMegakernelTopology, CudaMegakernelTopologyDecision,
22};
23
/// Entry bound used by [`CudaMegakernelPlanCache::new`].
const DEFAULT_MAX_MEGAKERNEL_PLANS: usize = 256;
/// Width of one pressure bucket in basis points; `pressure_bucket` divides by it.
const PRESSURE_BUCKET_BPS: u32 = 1_000;
/// Number of frontier-density buckets; `density_bucket` yields indices
/// `0..DENSITY_BUCKETS`.
const DENSITY_BUCKETS: u16 = 16;
/// Right shift applied to readback byte counts before `readback_bucket` takes
/// their bit length, so every count below `1 << 12` bytes lands in bucket 0.
const READBACK_BUCKET_SHIFT: u32 = 12;
28
/// Analysis family that a cached CUDA megakernel plan was selected for.
///
/// The family is one component of [`CudaMegakernelPlanCacheKey`], so plans
/// selected for different families are stored under different keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum CudaMegakernelAnalysisKind {
    /// Wave of generic graph dataflow.
    Dataflow,
    /// Exploded-supergraph propagation in the IFDS/IDE style.
    Ifds,
    /// Propagation of reaching definitions.
    ReachingDefinitions,
    /// Propagation of live variables.
    Liveness,
    /// Propagation of points-to facts.
    PointsTo,
    /// Wave over source tokens or a parser frontier.
    ParserFrontend,
    /// Analysis family owned by the caller, named by a stable numeric tag.
    Custom(u64),
}
47
/// Feature signature of a CUDA device; a change in any field invalidates
/// cached megakernel plans because the signature is part of the cache key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct CudaMegakernelDeviceKey {
    /// Major component of the CUDA SM version.
    pub sm_major: u16,
    /// Minor component of the CUDA SM version.
    pub sm_minor: u16,
    /// Warp size of the hardware.
    pub warp_size: u16,
    /// True when cooperative grid synchronization can be used.
    pub supports_grid_sync: bool,
    /// True when tensor-core lowering can be used in this backend session.
    pub supports_tensor_cores: bool,
    /// Largest thread count accepted for a single workgroup/block.
    pub max_workgroup_size: u32,
}
64
65impl From<&CudaDeviceCaps> for CudaMegakernelDeviceKey {
66    fn from(caps: &CudaDeviceCaps) -> Self {
67        Self {
68            sm_major: caps.compute_capability.0.min(u32::from(u16::MAX)) as u16,
69            sm_minor: caps.compute_capability.1.min(u32::from(u16::MAX)) as u16,
70            warp_size: caps.required_warp_size_u32().min(u32::from(u16::MAX)) as u16,
71            supports_grid_sync: caps.compute_capability >= (6, 0) && caps.cooperative_launch,
72            supports_tensor_cores: caps.hardware_supports_tensor_cores(),
73            max_workgroup_size: caps.max_threads_per_block_u32(),
74        }
75    }
76}
77
78/// Stable key for cached CUDA megakernel plans.
79#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash, Ord, PartialOrd)]
80pub struct CudaMegakernelPlanCacheKey {
81    /// Stable hash of the normalized resident graph layout.
82    pub graph_layout_hash: u64,
83    /// Analysis family consuming the graph layout.
84    pub analysis_kind: CudaMegakernelAnalysisKind,
85    /// CUDA device feature signature.
86    pub device: CudaMegakernelDeviceKey,
87    /// Coarse active-frontier density bucket.
88    pub frontier_density_bucket: u16,
89    /// Coarse memory-pressure bucket in basis points.
90    pub memory_pressure_bucket: u32,
91    /// Coarse output/readback pressure bucket.
92    pub readback_pressure_bucket: u16,
93    /// Coarse launch-over-dispatch pressure bucket in basis points.
94    pub launch_pressure_bucket: u32,
95    /// Coarse caller-provided fusion-pressure bucket.
96    pub fusion_pressure_bucket: u32,
97}
98
99impl CudaMegakernelPlanCacheKey {
100    /// Build a cache key from stable identity fields and runtime pressure.
101    #[must_use]
102    pub fn new(
103        graph_layout_hash: u64,
104        analysis_kind: CudaMegakernelAnalysisKind,
105        device: CudaMegakernelDeviceKey,
106        frontier_density: f64,
107        memory_pressure_bps: u32,
108        readback_bytes: u64,
109        launch_pressure_bps: u32,
110        fusion_pressure: f64,
111    ) -> Self {
112        Self {
113            graph_layout_hash,
114            analysis_kind,
115            device,
116            frontier_density_bucket: density_bucket(frontier_density),
117            memory_pressure_bucket: pressure_bucket(memory_pressure_bps),
118            readback_pressure_bucket: readback_bucket(readback_bytes),
119            launch_pressure_bucket: pressure_bucket(launch_pressure_bps),
120            fusion_pressure_bucket: fusion_bucket(fusion_pressure),
121        }
122    }
123
124    fn identity(self) -> CudaMegakernelPlanIdentityKey {
125        CudaMegakernelPlanIdentityKey {
126            graph_layout_hash: self.graph_layout_hash,
127            analysis_kind: self.analysis_kind,
128            device: self.device,
129        }
130    }
131}
132
133#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash, Ord, PartialOrd)]
134struct CudaMegakernelPlanIdentityKey {
135    graph_layout_hash: u64,
136    analysis_kind: CudaMegakernelAnalysisKind,
137    device: CudaMegakernelDeviceKey,
138}
139
140/// Cached CUDA megakernel plan.
141#[derive(Clone, Copy, Debug, PartialEq)]
142pub struct CudaMegakernelCachedPlan {
143    /// Selected topology for this key.
144    pub topology: CudaMegakernelTopology,
145    /// Full decision telemetry used when the plan was inserted.
146    pub decision: CudaMegakernelTopologyDecision,
147}
148
/// Counter snapshot reported by [`CudaMegakernelPlanCache`].
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct CudaMegakernelPlanCacheStats {
    /// Number of lookups that found a cached plan.
    pub hits: u64,
    /// Number of lookups that found no cached plan.
    pub misses: u64,
    /// Number of entries removed by the bounded LRU policy.
    pub evictions: u64,
    /// Number of entries currently stored.
    pub entries: usize,
}
161
162#[derive(Clone, Copy, Debug)]
163struct CudaMegakernelPlanCacheEntry {
164    plan: CudaMegakernelCachedPlan,
165    last_seen: u64,
166}
167
168/// Bounded LRU cache for CUDA megakernel topology plans.
169#[derive(Debug)]
170pub struct CudaMegakernelPlanCache {
171    entries: FxHashMap<CudaMegakernelPlanCacheKey, CudaMegakernelPlanCacheEntry>,
172    latest_by_identity: FxHashMap<CudaMegakernelPlanIdentityKey, (u64, CudaMegakernelTopology)>,
173    eviction_queue: BinaryHeap<Reverse<(u64, CudaMegakernelPlanCacheKey)>>,
174    max_entries: usize,
175    serial: u64,
176    hits: u64,
177    misses: u64,
178    evictions: u64,
179}
180
181fn increment_plan_cache_counter(counter: &mut u64, field: &'static str) {
182    vyre_driver::accounting::pinning_increment_u64(counter, || {
183        tracing::error!(
184            "CUDA megakernel {field} overflowed u64; pinning counter at u64::MAX. Fix: scrape metrics more frequently or shard the cache."
185        );
186    });
187}
188
189impl Default for CudaMegakernelPlanCache {
190    fn default() -> Self {
191        Self::new()
192    }
193}
194
195impl CudaMegakernelPlanCache {
    /// Create a cache with the default production entry bound.
    ///
    /// Equivalent to [`Self::with_max_entries`] called with
    /// `DEFAULT_MAX_MEGAKERNEL_PLANS`.
    #[must_use]
    pub fn new() -> Self {
        Self::with_max_entries(DEFAULT_MAX_MEGAKERNEL_PLANS)
    }
201
202    /// Create a cache with an explicit entry bound.
203    #[must_use]
204    pub fn with_max_entries(max_entries: usize) -> Self {
205        Self {
206            entries: FxHashMap::default(),
207            latest_by_identity: FxHashMap::default(),
208            eviction_queue: BinaryHeap::new(),
209            max_entries,
210            serial: 0,
211            hits: 0,
212            misses: 0,
213            evictions: 0,
214        }
215    }
216
217    /// Return a cached plan or insert a newly selected topology decision.
218    pub fn get_or_insert_with(
219        &mut self,
220        key: CudaMegakernelPlanCacheKey,
221        build: impl FnOnce() -> CudaMegakernelTopologyDecision,
222    ) -> Result<CudaMegakernelCachedPlan, CudaMegakernelMemoryError> {
223        let serial = self.advance_serial()?;
224        if let Some(entry) = self.entries.get_mut(&key) {
225            increment_plan_cache_counter(&mut self.hits, "megakernel plan-cache hit counter");
226            entry.last_seen = serial;
227            let plan = entry.plan;
228            self.eviction_queue.push(Reverse((serial, key)));
229            self.update_latest_identity(key.identity(), serial, plan.topology);
230            return Ok(plan);
231        }
232        increment_plan_cache_counter(&mut self.misses, "megakernel plan-cache miss counter");
233        if self.max_entries == 0 {
234            let decision = build();
235            return Ok(CudaMegakernelCachedPlan {
236                topology: decision.topology,
237                decision,
238            });
239        }
240        self.evict_until_below_limit()?;
241        let decision = build();
242        let plan = CudaMegakernelCachedPlan {
243            topology: decision.topology,
244            decision,
245        };
246        self.entries.insert(
247            key,
248            CudaMegakernelPlanCacheEntry {
249                plan,
250                last_seen: serial,
251            },
252        );
253        self.eviction_queue.push(Reverse((serial, key)));
254        self.update_latest_identity(key.identity(), serial, plan.topology);
255        Ok(plan)
256    }
257
258    /// Return a cached topology plan or select and cache one from the current
259    /// CUDA telemetry sample.
260    ///
261    /// This is the hot-path convenience API: callers provide stable graph,
262    /// analysis, device, and telemetry inputs, while the cache owns the
263    /// pressure bucketing needed to avoid stale sparse/dense decisions.
264    pub fn get_or_select_topology(
265        &mut self,
266        graph_layout_hash: u64,
267        analysis_kind: CudaMegakernelAnalysisKind,
268        device: CudaMegakernelDeviceKey,
269        sample: CudaMegakernelScheduleSample,
270        graph: CudaMegakernelGraphShape,
271        memory: CudaMegakernelMemoryBudget,
272        launch_overhead_ns: f64,
273        fusion_pressure: f64,
274    ) -> Result<CudaMegakernelCachedPlan, CudaMegakernelMemoryError> {
275        let effective_fusion_pressure = if device.supports_grid_sync {
276            fusion_pressure
277        } else {
278            0.0
279        };
280        let key = CudaMegakernelPlanCacheKey::new(
281            graph_layout_hash,
282            analysis_kind,
283            device,
284            sample.frontier_density,
285            pressure_bps(memory.required_bytes, memory.budget_bytes),
286            sample.readback_bytes,
287            launch_pressure_bps(sample.dispatch_cost_ns, launch_overhead_ns),
288            effective_fusion_pressure,
289        );
290        let previous_topology =
291            self.latest_topology_for_identity(graph_layout_hash, analysis_kind, device);
292        self.get_or_insert_with(key, || {
293            if let Some(previous_topology) = previous_topology {
294                select_cuda_megakernel_topology_stable(
295                    sample,
296                    graph,
297                    memory,
298                    launch_overhead_ns,
299                    effective_fusion_pressure,
300                    previous_topology,
301                )
302            } else {
303                select_cuda_megakernel_topology(
304                    sample,
305                    graph,
306                    memory,
307                    launch_overhead_ns,
308                    effective_fusion_pressure,
309                )
310            }
311        })
312    }
313
314    /// Return a cache-backed, memory-validated CUDA megakernel execution plan.
315    ///
316    /// The cache key uses sparse-plan memory pressure because sparse is the
317    /// lower-bound resident footprint shared by every topology. A cache hit
318    /// reuses the prior topology decision, then this method validates the exact
319    /// current dense/fused/sparse byte budget before returning a launchable
320    /// plan. If the cached non-sparse topology no longer fits, the method
321    /// downgrades to sparse only after proving the sparse plan fits.
322    pub fn get_or_plan_execution(
323        &mut self,
324        graph_layout_hash: u64,
325        analysis_kind: CudaMegakernelAnalysisKind,
326        device: CudaMegakernelDeviceKey,
327        sample: CudaMegakernelScheduleSample,
328        graph: CudaMegakernelGraphShape,
329        bytes_per_node: u64,
330        bytes_per_edge: u64,
331        frontier_bytes: u64,
332        scratch_bytes: u64,
333        output_bytes: u64,
334        budget_bytes: u64,
335        launch_overhead_ns: f64,
336        fusion_pressure: f64,
337    ) -> Result<CudaMegakernelExecutionPlan, CudaMegakernelMemoryError> {
338        let sparse_memory = plan_cuda_megakernel_memory_budget(
339            CudaMegakernelTopology::SparseFrontier,
340            graph,
341            bytes_per_node,
342            bytes_per_edge,
343            frontier_bytes,
344            scratch_bytes,
345            output_bytes,
346            u64::MAX,
347        )?;
348        let cached = self.get_or_select_topology(
349            graph_layout_hash,
350            analysis_kind,
351            device,
352            sample,
353            graph,
354            CudaMegakernelMemoryBudget {
355                required_bytes: sparse_memory.required_bytes,
356                budget_bytes,
357            },
358            launch_overhead_ns,
359            fusion_pressure,
360        )?;
361        match plan_cuda_megakernel_memory_budget(
362            cached.topology,
363            graph,
364            bytes_per_node,
365            bytes_per_edge,
366            frontier_bytes,
367            scratch_bytes,
368            output_bytes,
369            budget_bytes,
370        ) {
371            Ok(memory) => Ok(CudaMegakernelExecutionPlan {
372                topology: cached.topology,
373                memory,
374                downgraded_to_sparse: false,
375            }),
376            Err(CudaMegakernelMemoryError::OverBudget { .. })
377                if cached.topology != CudaMegakernelTopology::SparseFrontier =>
378            {
379                let memory = plan_cuda_megakernel_memory_budget(
380                    CudaMegakernelTopology::SparseFrontier,
381                    graph,
382                    bytes_per_node,
383                    bytes_per_edge,
384                    frontier_bytes,
385                    scratch_bytes,
386                    output_bytes,
387                    budget_bytes,
388                )?;
389                Ok(CudaMegakernelExecutionPlan {
390                    topology: CudaMegakernelTopology::SparseFrontier,
391                    memory,
392                    downgraded_to_sparse: true,
393                })
394            }
395            Err(error) => Err(error),
396        }
397    }
398
399    /// Return cache counters.
400    #[must_use]
401    pub fn stats(&self) -> CudaMegakernelPlanCacheStats {
402        CudaMegakernelPlanCacheStats {
403            hits: self.hits,
404            misses: self.misses,
405            evictions: self.evictions,
406            entries: self.entries.len(),
407        }
408    }
409
410    /// Drop every cached plan and preserve counters for observability.
411    pub fn clear(&mut self) {
412        self.entries.clear();
413        self.latest_by_identity.clear();
414        self.eviction_queue.clear();
415    }
416
417    fn latest_topology_for_identity(
418        &self,
419        graph_layout_hash: u64,
420        analysis_kind: CudaMegakernelAnalysisKind,
421        device: CudaMegakernelDeviceKey,
422    ) -> Option<CudaMegakernelTopology> {
423        self.latest_by_identity
424            .get(&CudaMegakernelPlanIdentityKey {
425                graph_layout_hash,
426                analysis_kind,
427                device,
428            })
429            .map(|(_, topology)| *topology)
430    }
431
432    fn update_latest_identity(
433        &mut self,
434        identity: CudaMegakernelPlanIdentityKey,
435        serial: u64,
436        topology: CudaMegakernelTopology,
437    ) {
438        match self.latest_by_identity.get(&identity) {
439            Some((latest_serial, _)) if *latest_serial > serial => {}
440            _ => {
441                self.latest_by_identity.insert(identity, (serial, topology));
442            }
443        }
444    }
445
446    fn recompute_latest_identity(&mut self, identity: CudaMegakernelPlanIdentityKey) {
447        let latest = self
448            .entries
449            .iter()
450            .filter(|(key, _)| key.identity() == identity)
451            .max_by_key(|(_, entry)| entry.last_seen)
452            .map(|(_, entry)| (entry.last_seen, entry.plan.topology));
453        if let Some(latest) = latest {
454            self.latest_by_identity.insert(identity, latest);
455        } else {
456            self.latest_by_identity.remove(&identity);
457        }
458    }
459
460    fn evict_until_below_limit(&mut self) -> Result<(), CudaMegakernelMemoryError> {
461        while self.entries.len() >= self.max_entries {
462            let Some(Reverse((last_seen, lru_key))) = self.eviction_queue.pop() else {
463                break;
464            };
465            let Some(entry) = self.entries.get(&lru_key) else {
466                continue;
467            };
468            if entry.last_seen != last_seen {
469                continue;
470            }
471            let identity = lru_key.identity();
472            let evicted_topology = entry.plan.topology;
473            self.entries.remove(&lru_key);
474            if matches!(
475                self.latest_by_identity.get(&identity),
476                Some((latest_seen, latest_topology))
477                    if *latest_seen == last_seen && *latest_topology == evicted_topology
478            ) {
479                self.recompute_latest_identity(identity);
480            }
481            increment_plan_cache_counter(
482                &mut self.evictions,
483                "megakernel plan-cache eviction counter",
484            );
485        }
486        Ok(())
487    }
488
489    fn advance_serial(&mut self) -> Result<u64, CudaMegakernelMemoryError> {
490        if let Some(next) = self.serial.checked_add(1) {
491            self.serial = next;
492            return Ok(next);
493        }
494        self.rebase_lru_serials()?;
495        self.serial =
496            self.serial
497                .checked_add(1)
498                .ok_or(CudaMegakernelMemoryError::ByteCountOverflow {
499                    field: "megakernel plan-cache LRU serial after rebase",
500                })?;
501        Ok(self.serial)
502    }
503
504    fn rebase_lru_serials(&mut self) -> Result<(), CudaMegakernelMemoryError> {
505        let mut ordered = Vec::new();
506        reserve_vec(
507            &mut ordered,
508            self.entries.len(),
509            "megakernel plan-cache LRU rebase scratch",
510        )
511        .map_err(|_| CudaMegakernelMemoryError::ByteCountOverflow {
512            field: "megakernel plan-cache LRU rebase scratch",
513        })?;
514        for (key, entry) in &self.entries {
515            ordered.push((entry.last_seen, *key));
516        }
517        sort_unstable_by_key_if_needed(&mut ordered, |(last_seen, key)| (*last_seen, *key));
518        self.eviction_queue.clear();
519        self.latest_by_identity.clear();
520        let mut serial = 0_u64;
521        for (_, key) in ordered {
522            serial = serial
523                .checked_add(1)
524                .ok_or(CudaMegakernelMemoryError::ByteCountOverflow {
525                    field: "megakernel plan-cache LRU rebase serial",
526                })?;
527            let topology = if let Some(entry) = self.entries.get_mut(&key) {
528                entry.last_seen = serial;
529                Some(entry.plan.topology)
530            } else {
531                None
532            };
533            if let Some(topology) = topology {
534                self.eviction_queue.push(Reverse((serial, key)));
535                self.update_latest_identity(key.identity(), serial, topology);
536            }
537        }
538        self.serial = serial;
539        Ok(())
540    }
541}
542
543
544fn density_bucket(frontier_density: f64) -> u16 {
545    if !frontier_density.is_finite() {
546        return 0;
547    }
548    let clamped = frontier_density.clamp(0.0, 1.0);
549    rounded_f64_to_u16_bucket(
550        clamped * f64::from(DENSITY_BUCKETS - 1),
551        "frontier-density bucket",
552    )
553}
554
/// Reduce a basis-point pressure value to its bucket index by integer
/// division with `PRESSURE_BUCKET_BPS`. Despite the parameter name, this is
/// also used for launch and fusion pressure.
fn pressure_bucket(memory_pressure_bps: u32) -> u32 {
    memory_pressure_bps / PRESSURE_BUCKET_BPS
}
558
559fn pressure_bps(numerator: u64, denominator: u64) -> u32 {
560    crate::numeric::CUDA_NUMERIC.ratio_basis_points_u64(
561        numerator,
562        denominator,
563        if numerator == 0 { 0 } else { u32::MAX },
564        "megakernel pressure",
565    )
566}
567
/// Compute launch pressure in basis points from launch overhead and dispatch
/// cost via the shared CUDA numeric helper.
///
/// Note the argument order: `launch_overhead_ns` is passed first and
/// `dispatch_cost_ns` second — presumably numerator and denominator; the
/// meaning of the `u32::MAX` and `0` arguments is defined by the helper
/// (TODO confirm against `CUDA_NUMERIC`).
fn launch_pressure_bps(dispatch_cost_ns: f64, launch_overhead_ns: f64) -> u32 {
    crate::numeric::CUDA_NUMERIC.finite_f64_ratio_basis_points_trunc(
        launch_overhead_ns,
        dispatch_cost_ns,
        u32::MAX,
        0,
        "launch-pressure basis-points",
    )
}
577
578fn readback_bucket(readback_bytes: u64) -> u16 {
579    if readback_bytes == 0 {
580        return 0;
581    }
582    let shifted = readback_bytes >> READBACK_BUCKET_SHIFT;
583    let bucket = u64::BITS - shifted.leading_zeros();
584    bucket.min(u32::from(u16::MAX)) as u16
585}
586
587fn fusion_bucket(fusion_pressure: f64) -> u32 {
588    pressure_bucket(
589        crate::numeric::CUDA_NUMERIC.finite_f64_unit_basis_points_trunc(
590            fusion_pressure,
591            0,
592            "fusion-pressure basis-points",
593        ),
594    )
595}
596
597fn rounded_f64_to_u16_bucket(value: f64, label: &'static str) -> u16 {
598    let rounded = value.round();
599    if !rounded.is_finite() || rounded < 0.0 || rounded > f64::from(u16::MAX) {
600        tracing::error!(
601            "CUDA megakernel {label} value {rounded} cannot fit u16. Fix: reduce bucket resolution or shard cache domains."
602        );
603        return if rounded.is_sign_negative() {
604            0
605        } else {
606            u16::MAX
607        };
608    }
609    rounded as u16
610}
611
612#[cfg(test)]
613mod tests {
614    use super::{
615        CudaMegakernelAnalysisKind, CudaMegakernelDeviceKey, CudaMegakernelPlanCache,
616        CudaMegakernelPlanCacheKey,
617    };
618    use crate::megakernel_scheduler::{
619        CudaMegakernelGraphShape, CudaMegakernelScheduleSample, CudaMegakernelTopology,
620        CudaMegakernelTopologyDecision,
621    };
622    use crate::synthetic_device_caps::blackwell_sm120_caps_default;
623
624    fn device() -> CudaMegakernelDeviceKey {
625        CudaMegakernelDeviceKey {
626            sm_major: 12,
627            sm_minor: 0,
628            warp_size: 32,
629            supports_grid_sync: true,
630            supports_tensor_cores: true,
631            max_workgroup_size: 1024,
632        }
633    }
634
635    fn key(
636        graph_layout_hash: u64,
637        analysis_kind: CudaMegakernelAnalysisKind,
638        frontier_density: f64,
639        memory_pressure_bps: u32,
640    ) -> CudaMegakernelPlanCacheKey {
641        CudaMegakernelPlanCacheKey::new(
642            graph_layout_hash,
643            analysis_kind,
644            device(),
645            frontier_density,
646            memory_pressure_bps,
647            0,
648            0,
649            0.0,
650        )
651    }
652
653    fn decision(topology: CudaMegakernelTopology) -> CudaMegakernelTopologyDecision {
654        CudaMegakernelTopologyDecision {
655            topology,
656            memory_pressure_bps: 1_000,
657            average_degree_bps: 20_000,
658            launch_pressure_bps: 2_000,
659        }
660    }
661
662    #[test]
663    fn cache_reuses_plan_for_same_graph_analysis_device_and_pressure_bucket() {
664        let mut cache = CudaMegakernelPlanCache::new();
665        let key = key(42, CudaMegakernelAnalysisKind::Ifds, 0.52, 2_400);
666        let first = cache
667            .get_or_insert_with(key, || decision(CudaMegakernelTopology::FusedWave))
668            .expect("Fix: CUDA megakernel plan-cache insert should fit telemetry counters.");
669        let second = cache
670            .get_or_insert_with(key, || decision(CudaMegakernelTopology::SparseFrontier))
671            .expect("Fix: CUDA megakernel plan-cache hit should fit telemetry counters.");
672
673        assert_eq!(first, second);
674        assert_eq!(second.topology, CudaMegakernelTopology::FusedWave);
675        let stats = cache.stats();
676        assert_eq!(stats.hits, 1);
677        assert_eq!(stats.misses, 1);
678        assert_eq!(stats.entries, 1);
679    }
680
681    #[test]
682    fn device_key_is_derived_from_cuda_caps() {
683        assert_eq!(
684            CudaMegakernelDeviceKey::from(&blackwell_sm120_caps_default()),
685            device()
686        );
687    }
688
689    #[test]
690    fn cache_separates_analysis_family_density_and_device_features() {
691        let ifds = key(42, CudaMegakernelAnalysisKind::Ifds, 0.01, 1_000);
692        let liveness = key(42, CudaMegakernelAnalysisKind::Liveness, 0.01, 1_000);
693        let dense = key(42, CudaMegakernelAnalysisKind::Ifds, 0.95, 1_000);
694        let mut other_device = device();
695        other_device.sm_minor = 1;
696        let device_changed = CudaMegakernelPlanCacheKey::new(
697            42,
698            CudaMegakernelAnalysisKind::Ifds,
699            other_device,
700            0.01,
701            1_000,
702            0,
703            0,
704            0.0,
705        );
706
707        assert_ne!(ifds, liveness);
708        assert_ne!(ifds, dense);
709        assert_ne!(ifds, device_changed);
710    }
711
712    #[test]
713    fn bounded_cache_evicts_lru_entry() {
714        let mut cache = CudaMegakernelPlanCache::with_max_entries(2);
715        let first = key(1, CudaMegakernelAnalysisKind::Dataflow, 0.1, 1_000);
716        let second = key(2, CudaMegakernelAnalysisKind::Dataflow, 0.1, 1_000);
717        let third = key(3, CudaMegakernelAnalysisKind::Dataflow, 0.1, 1_000);
718
719        cache
720            .get_or_insert_with(first, || decision(CudaMegakernelTopology::SparseFrontier))
721            .expect("Fix: CUDA megakernel plan-cache insert should fit telemetry counters.");
722        cache
723            .get_or_insert_with(second, || decision(CudaMegakernelTopology::HybridFrontier))
724            .expect("Fix: CUDA megakernel plan-cache insert should fit telemetry counters.");
725        cache
726            .get_or_insert_with(first, || decision(CudaMegakernelTopology::DenseFrontier))
727            .expect("Fix: CUDA megakernel plan-cache hit should fit telemetry counters.");
728        cache
729            .get_or_insert_with(third, || decision(CudaMegakernelTopology::FusedWave))
730            .expect("Fix: CUDA megakernel plan-cache eviction should fit telemetry counters.");
731
732        let stats = cache.stats();
733        assert_eq!(stats.hits, 1);
734        assert_eq!(stats.misses, 3);
735        assert_eq!(stats.evictions, 1);
736        assert_eq!(stats.entries, 2);
737        let reloaded_second = cache
738            .get_or_insert_with(second, || decision(CudaMegakernelTopology::DenseFrontier))
739            .expect("Fix: CUDA megakernel plan-cache reload should fit telemetry counters.");
740        assert_eq!(
741            reloaded_second.topology,
742            CudaMegakernelTopology::DenseFrontier
743        );
744    }
745
746    #[test]
747    fn cache_selects_topology_and_reuses_pressure_bucket_plan() {
748        let mut cache = CudaMegakernelPlanCache::new();
749        let sample = crate::megakernel_scheduler::CudaMegakernelScheduleSample {
750            dispatch_cost_ns: 1_000.0,
751            frontier_density: 0.90,
752            readback_bytes: 1 << 20,
753        };
754        let graph = crate::megakernel_scheduler::CudaMegakernelGraphShape {
755            node_count: 1_000,
756            edge_count: 4_000,
757        };
758        let memory = crate::megakernel_scheduler::CudaMegakernelMemoryBudget {
759            required_bytes: 1_024,
760            budget_bytes: 16_384,
761        };
762        let first = cache
763            .get_or_select_topology(
764                99,
765                CudaMegakernelAnalysisKind::Dataflow,
766                device(),
767                sample,
768                graph,
769                memory,
770                250.0,
771                0.95,
772            )
773            .expect("Fix: CUDA megakernel topology selection should fit telemetry counters.");
774        let second = cache
775            .get_or_select_topology(
776                99,
777                CudaMegakernelAnalysisKind::Dataflow,
778                device(),
779                crate::megakernel_scheduler::CudaMegakernelScheduleSample {
780                    frontier_density: 0.91,
781                    ..sample
782                },
783                graph,
784                crate::megakernel_scheduler::CudaMegakernelMemoryBudget {
785                    required_bytes: 1_100,
786                    budget_bytes: 16_384,
787                },
788                250.0,
789                0.95,
790            )
791            .expect("Fix: CUDA megakernel topology cache hit should fit telemetry counters.");
792
793        assert_eq!(first, second);
794        assert_eq!(first.topology, CudaMegakernelTopology::FusedWave);
795        assert_eq!(cache.stats().hits, 1);
796        assert_eq!(cache.stats().misses, 1);
797    }
798
799    #[test]
800    fn cache_stabilizes_topology_across_adjacent_pressure_buckets() {
801        let mut cache = CudaMegakernelPlanCache::new();
802        let graph = crate::megakernel_scheduler::CudaMegakernelGraphShape {
803            node_count: 1_000,
804            edge_count: 4_000,
805        };
806        let memory = crate::megakernel_scheduler::CudaMegakernelMemoryBudget {
807            required_bytes: 1_024,
808            budget_bytes: 16_384,
809        };
810        let dense = cache
811            .get_or_select_topology(
812                99,
813                CudaMegakernelAnalysisKind::Dataflow,
814                device(),
815                crate::megakernel_scheduler::CudaMegakernelScheduleSample {
816                    dispatch_cost_ns: 1_000.0,
817                    frontier_density: 0.70,
818                    readback_bytes: 512,
819                },
820                graph,
821                memory,
822                100.0,
823                0.0,
824            )
825            .expect("Fix: CUDA megakernel topology selection should fit telemetry counters.");
826        let near_dense = cache
827            .get_or_select_topology(
828                99,
829                CudaMegakernelAnalysisKind::Dataflow,
830                device(),
831                crate::megakernel_scheduler::CudaMegakernelScheduleSample {
832                    dispatch_cost_ns: 1_000.0,
833                    frontier_density: 0.68,
834                    readback_bytes: 512,
835                },
836                graph,
837                memory,
838                100.0,
839                0.0,
840            )
841            .expect("Fix: CUDA megakernel topology stabilization should fit telemetry counters.");
842
843        assert_eq!(dense.topology, CudaMegakernelTopology::DenseFrontier);
844        assert_eq!(near_dense.topology, CudaMegakernelTopology::DenseFrontier);
845        assert_eq!(cache.stats().hits, 0);
846        assert_eq!(cache.stats().misses, 2);
847    }
848
849    #[test]
850    fn cache_reselects_when_memory_pressure_bucket_changes() {
851        let mut cache = CudaMegakernelPlanCache::new();
852        let sample = crate::megakernel_scheduler::CudaMegakernelScheduleSample {
853            dispatch_cost_ns: 1_000.0,
854            frontier_density: 0.90,
855            readback_bytes: 1 << 20,
856        };
857        let graph = crate::megakernel_scheduler::CudaMegakernelGraphShape {
858            node_count: 1_000,
859            edge_count: 4_000,
860        };
861        let low_pressure = cache
862            .get_or_select_topology(
863                99,
864                CudaMegakernelAnalysisKind::Dataflow,
865                device(),
866                sample,
867                graph,
868                crate::megakernel_scheduler::CudaMegakernelMemoryBudget {
869                    required_bytes: 1_024,
870                    budget_bytes: 16_384,
871                },
872                250.0,
873                0.95,
874            )
875            .expect("Fix: CUDA megakernel topology selection should fit telemetry counters.");
876        let red_zone = cache
877            .get_or_select_topology(
878                99,
879                CudaMegakernelAnalysisKind::Dataflow,
880                device(),
881                sample,
882                graph,
883                crate::megakernel_scheduler::CudaMegakernelMemoryBudget {
884                    required_bytes: 15_500,
885                    budget_bytes: 16_384,
886                },
887                250.0,
888                0.95,
889            )
890            .expect("Fix: CUDA megakernel topology reselection should fit telemetry counters.");
891
892        assert_eq!(low_pressure.topology, CudaMegakernelTopology::FusedWave);
893        assert_eq!(red_zone.topology, CudaMegakernelTopology::SparseFrontier);
894        assert_eq!(cache.stats().hits, 0);
895        assert_eq!(cache.stats().misses, 2);
896    }
897
898    #[test]
899    fn cache_pressure_bucket_uses_exact_u128_math() {
900        let low = CudaMegakernelPlanCacheKey::new(
901            1,
902            CudaMegakernelAnalysisKind::Dataflow,
903            device(),
904            0.5,
905            super::pressure_bps(1_u64 << 62, 1_u64 << 63),
906            0,
907            0,
908            0.0,
909        );
910        let high = CudaMegakernelPlanCacheKey::new(
911            1,
912            CudaMegakernelAnalysisKind::Dataflow,
913            device(),
914            0.5,
915            super::pressure_bps(1_u64 << 63, 1_u64 << 63),
916            0,
917            0,
918            0.0,
919        );
920
921        assert_eq!(low.memory_pressure_bucket, 5);
922        assert_eq!(high.memory_pressure_bucket, 10);
923    }
924
925    #[test]
926    fn cache_reselects_when_readback_launch_or_fusion_pressure_changes() {
927        let mut cache = CudaMegakernelPlanCache::new();
928        let graph = CudaMegakernelGraphShape {
929            node_count: 1_000,
930            edge_count: 4_000,
931        };
932        let memory = crate::megakernel_scheduler::CudaMegakernelMemoryBudget {
933            required_bytes: 1_024,
934            budget_bytes: 16_384,
935        };
936        let low_pressure = cache
937            .get_or_select_topology(
938                99,
939                CudaMegakernelAnalysisKind::Dataflow,
940                device(),
941                CudaMegakernelScheduleSample {
942                    dispatch_cost_ns: 1_000.0,
943                    frontier_density: 0.50,
944                    readback_bytes: 0,
945                },
946                graph,
947                memory,
948                250.0,
949                0.95,
950            )
951            .expect("Fix: CUDA megakernel topology selection should fit telemetry counters.");
952        let high_pressure = cache
953            .get_or_select_topology(
954                99,
955                CudaMegakernelAnalysisKind::Dataflow,
956                device(),
957                CudaMegakernelScheduleSample {
958                    dispatch_cost_ns: 1_000.0,
959                    frontier_density: 0.50,
960                    readback_bytes: 1 << 20,
961                },
962                graph,
963                memory,
964                250.0,
965                0.95,
966            )
967            .expect("Fix: CUDA megakernel topology pressure split should fit telemetry counters.");
968
969        assert_ne!(low_pressure.topology, CudaMegakernelTopology::FusedWave);
970        assert_eq!(high_pressure.topology, CudaMegakernelTopology::FusedWave);
971        assert_eq!(cache.stats().hits, 0);
972        assert_eq!(cache.stats().misses, 2);
973    }
974
975    #[test]
976    fn cache_never_selects_fused_wave_without_grid_sync_support() {
977        let mut cache = CudaMegakernelPlanCache::new();
978        let mut no_grid_sync = device();
979        no_grid_sync.supports_grid_sync = false;
980
981        let plan = cache
982            .get_or_select_topology(
983                99,
984                CudaMegakernelAnalysisKind::Dataflow,
985                no_grid_sync,
986                CudaMegakernelScheduleSample {
987                    dispatch_cost_ns: 1_000.0,
988                    frontier_density: 0.50,
989                    readback_bytes: 1 << 20,
990                },
991                CudaMegakernelGraphShape {
992                    node_count: 1_000,
993                    edge_count: 4_000,
994                },
995                crate::megakernel_scheduler::CudaMegakernelMemoryBudget {
996                    required_bytes: 1_024,
997                    budget_bytes: 16_384,
998                },
999                250.0,
1000                0.95,
1001            )
1002            .expect("Fix: CUDA megakernel topology selection should fit telemetry counters.");
1003
1004        assert_ne!(
1005            plan.topology,
1006            CudaMegakernelTopology::FusedWave,
1007            "Fix: CUDA megakernel planner must not select cooperative fused-wave topology when the device key says grid sync is unavailable."
1008        );
1009    }
1010
1011    #[test]
1012    fn cached_execution_plan_reuses_topology_bucket_and_validates_memory() {
1013        let mut cache = CudaMegakernelPlanCache::new();
1014        let sample = CudaMegakernelScheduleSample {
1015            dispatch_cost_ns: 1_000.0,
1016            frontier_density: 0.90,
1017            readback_bytes: 1 << 20,
1018        };
1019        let graph = CudaMegakernelGraphShape {
1020            node_count: 1_000,
1021            edge_count: 4_000,
1022        };
1023        let first = cache
1024            .get_or_plan_execution(
1025                99,
1026                CudaMegakernelAnalysisKind::Dataflow,
1027                device(),
1028                sample,
1029                graph,
1030                16,
1031                8,
1032                4_096,
1033                2_048,
1034                512,
1035                128 * 1024,
1036                250.0,
1037                0.95,
1038            )
1039            .expect("Fix: cache-backed fused CUDA execution plan should fit the explicit budget.");
1040        let second = cache
1041            .get_or_plan_execution(
1042                99,
1043                CudaMegakernelAnalysisKind::Dataflow,
1044                device(),
1045                CudaMegakernelScheduleSample {
1046                    frontier_density: 0.91,
1047                    ..sample
1048                },
1049                graph,
1050                16,
1051                8,
1052                4_096,
1053                2_048,
1054                512,
1055                128 * 1024,
1056                250.0,
1057                0.95,
1058            )
1059            .expect("Fix: equivalent CUDA execution pressure bucket should reuse the cached topology and still validate memory.");
1060
1061        assert_eq!(first.topology, CudaMegakernelTopology::FusedWave);
1062        assert_eq!(second.topology, CudaMegakernelTopology::FusedWave);
1063        assert_eq!(second.memory.scratch_bytes, 8_192);
1064        assert!(!second.downgraded_to_sparse);
1065        assert_eq!(cache.stats().hits, 1);
1066        assert_eq!(cache.stats().misses, 1);
1067    }
1068
1069    #[test]
1070    fn cached_execution_plan_downgrades_non_sparse_topology_when_exact_budget_fails() {
1071        let mut cache = CudaMegakernelPlanCache::new();
1072        let plan = cache
1073            .get_or_plan_execution(
1074                99,
1075                CudaMegakernelAnalysisKind::Dataflow,
1076                device(),
1077                CudaMegakernelScheduleSample {
1078                    dispatch_cost_ns: 1_000.0,
1079                    frontier_density: 0.50,
1080                    readback_bytes: 1 << 20,
1081                },
1082                CudaMegakernelGraphShape {
1083                    node_count: 1_000,
1084                    edge_count: 4_000,
1085                },
1086                16,
1087                8,
1088                4_096,
1089                10_000,
1090                512,
1091                80_000,
1092                250.0,
1093                0.90,
1094            )
1095            .expect("Fix: sparse CUDA downgrade must fit after cached fused topology exceeds exact budget.");
1096
1097        assert_eq!(plan.topology, CudaMegakernelTopology::SparseFrontier);
1098        assert!(plan.downgraded_to_sparse);
1099        assert_eq!(plan.memory.scratch_bytes, 10_000);
1100        assert_eq!(cache.stats().misses, 1);
1101        assert_eq!(cache.stats().entries, 1);
1102    }
1103
1104    #[test]
1105    fn cache_rebases_lru_serial_instead_of_failing_dispatch() {
1106        let mut cache = CudaMegakernelPlanCache::with_max_entries(2);
1107        let first = key(1, CudaMegakernelAnalysisKind::Ifds, 0.10, 1_000);
1108        let second = key(2, CudaMegakernelAnalysisKind::Ifds, 0.20, 1_000);
1109        cache
1110            .get_or_insert_with(first, || decision(CudaMegakernelTopology::SparseFrontier))
1111            .expect("Fix: first plan insert should fit");
1112        cache
1113            .get_or_insert_with(second, || decision(CudaMegakernelTopology::DenseFrontier))
1114            .expect("Fix: second plan insert should fit");
1115        cache.serial = u64::MAX;
1116
1117        cache
1118            .get_or_insert_with(first, || decision(CudaMegakernelTopology::FusedWave))
1119            .expect(
1120                "Fix: LRU serial exhaustion must rebase instead of failing the CUDA dispatch path",
1121            );
1122
1123        let first_seen = cache
1124            .entries
1125            .get(&first)
1126            .expect("Fix: first entry must remain")
1127            .last_seen;
1128        let second_seen = cache
1129            .entries
1130            .get(&second)
1131            .expect("Fix: second entry must remain")
1132            .last_seen;
1133        assert!(first_seen > second_seen);
1134        assert_eq!(cache.stats().hits, 1);
1135    }
1136
1137    #[test]
1138    fn cache_counters_pin_instead_of_failing_dispatch() {
1139        let mut cache = CudaMegakernelPlanCache::new();
1140        let key = key(3, CudaMegakernelAnalysisKind::Ifds, 0.10, 1_000);
1141        cache
1142            .get_or_insert_with(key, || decision(CudaMegakernelTopology::SparseFrontier))
1143            .expect("Fix: plan insert should fit");
1144        cache.hits = u64::MAX;
1145
1146        cache
1147            .get_or_insert_with(key, || decision(CudaMegakernelTopology::DenseFrontier))
1148            .expect("Fix: counter exhaustion must not fail the CUDA dispatch path");
1149
1150        assert_eq!(cache.stats().hits, u64::MAX);
1151    }
1152
1153    #[test]
1154    fn cache_eviction_is_queue_backed_not_map_scanned() {
1155        let src = include_str!("megakernel_plan_cache.rs");
1156        assert!(
1157            !src.contains(concat!(".iter()", ".min_by_key")),
1158            "Fix: CUDA megakernel plan-cache eviction must use the ordered eviction queue, not scan every cached plan on cold insert."
1159        );
1160        assert!(
1161            src.contains("BinaryHeap<Reverse<(u64, CudaMegakernelPlanCacheKey)>>"),
1162            "Fix: CUDA megakernel plan cache must keep an ordered LRU queue for hot-path eviction."
1163        );
1164        assert!(
1165            src.contains("increment_plan_cache_counter")
1166                && !src.contains(concat!(".", "saturating_add")),
1167            "Fix: CUDA megakernel plan-cache telemetry counters must pin loudly on overflow without hiding it behind saturating_add."
1168        );
1169        assert!(
1170            !src.contains(concat!("panic!", "(\"Fix: CUDA megakernel plan-cache")),
1171            "Fix: CUDA megakernel plan-cache overflow must return typed planner errors instead of panicking."
1172        );
1173        let production = src
1174            .split("#[cfg(test)]")
1175            .next()
1176            .expect("Fix: megakernel plan-cache source must contain production section");
1177        assert!(
1178            !production.contains(concat!("panic", "!("))
1179                && !production.contains(".expect(")
1180                && !production.contains(".unwrap_or_else(")
1181                && !production.contains("assert!("),
1182            "Fix: CUDA megakernel plan-cache production pressure bucketing and accounting must not panic."
1183        );
1184        assert!(
1185            production.contains("pub memory_pressure_bucket: u32")
1186                && production.contains("pub launch_pressure_bucket: u32")
1187                && production.contains("pub fusion_pressure_bucket: u32")
1188                && production.contains("tracing::error!"),
1189            "Fix: CUDA megakernel plan-cache pressure buckets must be wide enough for release telemetry and overflow must remain observable."
1190        );
1191        assert!(
1192            !src.contains(concat!(".", "wrapping_add"))
1193                && src.contains("fn rebase_lru_serials")
1194                && src.contains("fn advance_serial"),
1195            "Fix: CUDA megakernel plan-cache LRU serial must rebase on overflow, not wrap or fail hot dispatch."
1196        );
1197        assert!(
1198            production.contains("use crate::backend::ordering::sort_unstable_by_key_if_needed;")
1199                && production.contains("sort_unstable_by_key_if_needed(&mut ordered"),
1200            "Fix: CUDA megakernel plan-cache LRU rebase must use the shared monotonic sort fast path instead of a bespoke unconditional sort."
1201        );
1202        assert!(
1203            !production.contains(".sort_unstable_by_key("),
1204            "Fix: CUDA megakernel plan-cache production code must not reintroduce unconditional key sorting."
1205        );
1206        let latest_lookup = production
1207            .split("fn latest_topology_for_identity")
1208            .nth(1)
1209            .expect("Fix: CUDA megakernel plan-cache must expose previous-topology lookup")
1210            .split("fn update_latest_identity")
1211            .next()
1212            .expect("Fix: CUDA megakernel plan-cache lookup function must be bounded");
1213        assert!(
1214            latest_lookup.contains("latest_by_identity")
1215                && latest_lookup.contains(".get(&CudaMegakernelPlanIdentityKey")
1216                && !latest_lookup.contains(".iter()"),
1217            "Fix: previous-topology lookup must use the identity index instead of scanning every cached plan on cache miss."
1218        );
1219    }
1220}
1221