Skip to main content

vyre_driver/
device_extraction.rs

1//! Device-conditioned e-graph extraction helpers.
2//!
3//! Equality saturation is substrate-neutral: one saturated e-graph can hold
4//! every proven equivalent representation of a computation. Extraction is the
5//! point where a concrete device should matter. This module keeps that
6//! device-conditioned choice in `vyre-driver`, so native, portable, secondary, and
7//! future backends share the same extraction contract instead of each lowering
8//! path inventing its own cost plumbing.
9
10use smallvec::SmallVec;
11use vyre_foundation::optimizer::eqsat::{extract_best, EClassId, EGraph, ENodeLang};
12
13use crate::autotune_store::AutotuneRecord;
14use crate::device_profile::DeviceProfile;
15use crate::extraction_cost::{device_aware_cost, NodeHints};
16use crate::trace_jit_policy::{decide_trace_jit_speculation, TraceJitDecision, TraceJitInputs};
17
18/// Device context used for one extraction from a saturated e-graph.
19#[derive(Debug, Clone, Copy, PartialEq, Eq)]
20pub struct ExtractionDevice<'a> {
21    /// Capability profile for the target backend/device.
22    pub profile: &'a DeviceProfile,
23    /// Last winning autotune record for the root shape on this device.
24    pub autotune_record: Option<&'a AutotuneRecord>,
25    /// Recent trace-JIT counters for the same shader family.
26    pub trace_jit: Option<TraceJitInputs>,
27    /// Whether the current root is known hot from runtime counters.
28    pub hot_path: bool,
29}
30
31impl<'a> ExtractionDevice<'a> {
32    /// Build an extraction context for `profile`.
33    #[must_use]
34    pub const fn new(profile: &'a DeviceProfile, hot_path: bool) -> Self {
35        Self {
36            profile,
37            autotune_record: None,
38            trace_jit: None,
39            hot_path,
40        }
41    }
42
43    /// Attach the last winning autotune record for this device/root.
44    #[must_use]
45    pub const fn with_autotune_record(mut self, record: &'a AutotuneRecord) -> Self {
46        self.autotune_record = Some(record);
47        self
48    }
49
50    /// Attach trace-JIT counters for the same shader family.
51    #[must_use]
52    pub const fn with_trace_jit(mut self, counters: TraceJitInputs) -> Self {
53        self.trace_jit = Some(counters);
54        self
55    }
56}
57
58/// Best equivalent e-node selected for one device profile.
59#[derive(Debug, Clone, PartialEq, Eq)]
60pub struct DeviceExtraction<L> {
61    /// Backend id from the source [`DeviceProfile`].
62    pub backend: &'static str,
63    /// Whether hot-path scaling was applied.
64    pub hot_path: bool,
65    /// Selected e-node.
66    pub node: L,
67    /// Total extracted cost, including child-class costs.
68    pub cost: u64,
69}
70
71/// Extract the best equivalent representation for one device profile.
72#[must_use]
73pub fn extract_best_for_device<L, B, H>(
74    egraph: &EGraph<L>,
75    root: EClassId,
76    device: ExtractionDevice<'_>,
77    base_cost: B,
78    hint_lookup: H,
79) -> Option<DeviceExtraction<L>>
80where
81    L: ENodeLang,
82    B: Fn(&L) -> u64,
83    H: Fn(&L) -> NodeHints,
84{
85    if root.0 as usize >= egraph.class_count() {
86        return None;
87    }
88    let profile_cost = device_aware_cost(device.profile, device.hot_path, &base_cost, &hint_lookup);
89    let cost = |node: &L| {
90        let hints = hint_lookup(node);
91        let cost = profile_cost(node);
92        apply_context_bias(cost, extraction_bias_bps(device, hints))
93    };
94    extract_best(egraph, root, cost).map(|(node, cost)| DeviceExtraction {
95        backend: device.profile.backend,
96        hot_path: device.hot_path,
97        node,
98        cost,
99    })
100}
101
102/// Extract best variants for several devices from the same saturated e-graph.
103///
104/// The e-graph is not rebuilt or re-saturated between devices; only the
105/// extractor cost closure changes. This is the shared substrate needed for
106/// "same saturated graph, native-optimal and portable-optimal variants" workflows.
107#[must_use]
108pub fn extract_best_for_devices<'a, L, B, H>(
109    egraph: &EGraph<L>,
110    root: EClassId,
111    devices: impl IntoIterator<Item = ExtractionDevice<'a>>,
112    base_cost: B,
113    hint_lookup: H,
114) -> SmallVec<[DeviceExtraction<L>; 4]>
115where
116    L: ENodeLang,
117    B: Fn(&L) -> u64,
118    H: Fn(&L) -> NodeHints,
119{
120    let mut out = SmallVec::new();
121    for device in devices {
122        if let Some(extracted) =
123            extract_best_for_device(egraph, root, device, &base_cost, &hint_lookup)
124        {
125            out.push(extracted);
126        }
127    }
128    out
129}
130
131fn extraction_bias_bps(device: ExtractionDevice<'_>, hints: NodeHints) -> u32 {
132    let mut bps = 10_000u32;
133    if let Some(record) = device.autotune_record {
134        if hints.compile_time_constant && record.unroll > 1 {
135            bps = scale_bps(bps, 8_000);
136        }
137        if hints.fp16_eligible && record.tile.iter().any(|dim| *dim > 1) {
138            bps = scale_bps(bps, 9_500);
139        }
140    }
141    if hints.compile_time_constant {
142        if let Some(counters) = device.trace_jit {
143            if matches!(
144                decide_trace_jit_speculation(counters),
145                TraceJitDecision::Speculate { .. }
146            ) {
147                bps = scale_bps(bps, 7_000);
148            }
149        }
150    }
151    bps.max(1)
152}
153
154fn scale_bps(lhs_bps: u32, rhs_bps: u32) -> u32 {
155    crate::numeric::compose_basis_points_u32(
156        lhs_bps,
157        rhs_bps,
158        "device extraction bias composition",
159        "driver",
160    )
161}
162
163fn apply_context_bias(cost: u64, bps: u32) -> u64 {
164    if bps >= 10_000 {
165        return cost;
166    }
167    crate::numeric::scale_u64_by_basis_points_floor_min(
168        cost,
169        bps,
170        1,
171        "device extraction context bias",
172        "driver",
173    )
174}
175
#[cfg(test)]
mod tests {
    use super::*;
    use vyre_foundation::optimizer::eqsat::{EChildren, EGraph};

    /// Leaf-only toy language: three interchangeable implementations of one
    /// computation that differ only in base cost and hints.
    #[derive(Debug, Clone, PartialEq, Eq, Hash)]
    enum Toy {
        Scalar,
        TensorCore,
        Specialized,
    }

    impl ENodeLang for Toy {
        fn children(&self) -> EChildren {
            // Every toy node is a leaf.
            EChildren::new()
        }

        fn with_children(&self, _children: &[EClassId]) -> Self {
            self.clone()
        }
    }

    /// Device-independent cost of each toy node.
    fn base_cost(node: &Toy) -> u64 {
        match node {
            Toy::Scalar => 10,
            Toy::Specialized => 11,
            Toy::TensorCore => 30,
        }
    }

    /// `TensorCore` is fp16-eligible, `Specialized` is a compile-time
    /// constant, and `Scalar` carries no hints.
    fn hints(node: &Toy) -> NodeHints {
        match node {
            Toy::Scalar => NodeHints::default(),
            Toy::TensorCore => NodeHints {
                fp16_eligible: true,
                compile_time_constant: false,
            },
            Toy::Specialized => NodeHints {
                fp16_eligible: false,
                compile_time_constant: true,
            },
        }
    }

    /// One class holding `Toy::Scalar` merged with `alternative`. Returns the
    /// rebuilt graph and the id of the `Scalar` node, used as the root.
    fn graph_with_alternative(alternative: Toy) -> (EGraph<Toy>, EClassId) {
        let mut graph = EGraph::new();
        let root = graph.add(Toy::Scalar);
        let other = graph.add(alternative);
        graph.union(root, other);
        graph.rebuild();
        (graph, root)
    }

    fn equivalent_toy_graph() -> (EGraph<Toy>, EClassId) {
        graph_with_alternative(Toy::TensorCore)
    }

    fn specialized_toy_graph() -> (EGraph<Toy>, EClassId) {
        graph_with_alternative(Toy::Specialized)
    }

    /// Conservative profile for `backend` with f16 and tensor cores enabled.
    fn tensor_core_profile(backend: &'static str) -> DeviceProfile {
        let mut profile = DeviceProfile::conservative(backend);
        profile.supports_f16 = true;
        profile.supports_tensor_cores = true;
        profile
    }

    /// Single-device extraction with the toy cost model. Panics if nothing is extracted.
    fn extract_one(
        graph: &EGraph<Toy>,
        root: EClassId,
        device: ExtractionDevice<'_>,
    ) -> DeviceExtraction<Toy> {
        extract_best_for_device(graph, root, device, base_cost, hints)
            .expect("Fix: equivalent toy graph must extract")
    }

    #[test]
    fn conservative_profile_extracts_scalar_variant() {
        let (graph, root) = equivalent_toy_graph();
        let profile = DeviceProfile::conservative("portable");

        let picked = extract_one(&graph, root, ExtractionDevice::new(&profile, true));

        assert_eq!(picked.backend, "portable");
        assert_eq!(picked.node, Toy::Scalar);
        assert_eq!(picked.cost, 5);
    }

    #[test]
    fn tensor_core_profile_extracts_fp16_variant() {
        let (graph, root) = equivalent_toy_graph();
        let profile = tensor_core_profile("native");

        let picked = extract_one(&graph, root, ExtractionDevice::new(&profile, true));

        assert_eq!(picked.backend, "native");
        assert_eq!(picked.node, Toy::TensorCore);
        assert_eq!(picked.cost, 4);
    }

    #[test]
    fn several_devices_extract_from_one_saturated_graph() {
        let (graph, root) = equivalent_toy_graph();
        let portable = DeviceProfile::conservative("portable");
        let native = tensor_core_profile("native");
        let devices = [
            ExtractionDevice::new(&portable, true),
            ExtractionDevice::new(&native, true),
        ];

        let variants = extract_best_for_devices(&graph, root, devices, base_cost, hints);

        let nodes: Vec<&Toy> = variants.iter().map(|variant| &variant.node).collect();
        assert_eq!(nodes, [&Toy::Scalar, &Toy::TensorCore]);
    }

    #[test]
    fn autotune_record_biases_compile_time_constant_variant() {
        let (graph, root) = specialized_toy_graph();
        let profile = DeviceProfile::conservative("native");
        let record = AutotuneRecord {
            workgroup_size: [128, 1, 1],
            unroll: 4,
            tile: [0, 0, 0],
            recorded_at: String::new(),
        };
        let device = ExtractionDevice::new(&profile, true).with_autotune_record(&record);

        let picked = extract_one(&graph, root, device);

        assert_eq!(picked.node, Toy::Specialized);
        assert_eq!(picked.cost, 4);
    }

    #[test]
    fn trace_jit_biases_specialized_variant_when_speculation_pays() {
        let (graph, root) = specialized_toy_graph();
        let profile = DeviceProfile::conservative("native");
        let counters = TraceJitInputs {
            shader_hit_count: 64,
            prediction_confidence_bps: 10_000,
            speculative_spec_cost_ns: 1,
            miss_cost_ns: 1_000_000,
        };
        let device = ExtractionDevice::new(&profile, true).with_trace_jit(counters);

        let picked = extract_one(&graph, root, device);

        assert_eq!(picked.node, Toy::Specialized);
        assert_eq!(picked.cost, 4);
    }

    #[test]
    fn missing_root_returns_no_variant() {
        let graph: EGraph<Toy> = EGraph::new();
        let profile = DeviceProfile::conservative("portable");

        let variants = extract_best_for_devices(
            &graph,
            EClassId(77),
            [ExtractionDevice::new(&profile, true)],
            base_cost,
            hints,
        );

        assert!(variants.is_empty());
    }
}