Skip to main content

vyre_driver/
extraction_cost.rs

1//! Device-profile-aware cost helpers for [`vyre_foundation::optimizer::eqsat::extract_best`].
2//!
3//! ROADMAP A7. The egraph extraction substrate (`extract_best`) accepts an
//! arbitrary `Fn(&L) -> u64` cost function. Each consumer Family used to
//! roll its own — passing a flat per-op cost table that ignored the
//! current device's tensor-core throughput, hot/cold path heat, and
//! FP16-eligibility hints.
8//!
9//! This module gives every Family one shared place to build a cost
10//! closure from `(DeviceProfile, hot_path_flag, base_cost_fn)`. The
11//! closure scales the base cost by:
12//!
13//! 1. **Hot/cold-path multiplier.** Hot-path nodes pay less per
14//!    abstract cost unit because the optimizer is willing to spend
15//!    more rewriter budget on them; cold-path nodes pay more so the
16//!    extractor prefers smaller (less optimised) representations.
//! 2. **Tensor-core scaling for FP16-eligible ALU work.** When the
//!    profile reports `supports_tensor_cores && supports_f16`, nodes
//!    whose hints mark them `fp16_eligible` are additionally scaled by
//!    the fixed tensor-core multiplier `TENSOR_CORE_COST_SCALE` (`0.25`,
//!    i.e. 4× cheaper than scalar f32) so the extractor prefers
//!    FP16-eligible variants on supporting hardware.
23//!
//! Every multiplier is carried as integer basis points (`10_000` bps =
//! 1.0×), composed and applied in integer arithmetic, and clamped to at
//! most `40_000` bps (4.0×) before the scaled cost is rounded back to
//! `u64`. The base cost function still drives the *shape* of the cost
//! landscape; the profile only nudges relative weights.
28//!
29//! Pure functional value: no global state, no allocation, no I/O.
30//! Two profiles with identical capability bits produce identical
31//! closures so the extractor result is deterministic per device.
32
33use vyre_foundation::optimizer::eqsat::ENodeLang;
34
35use crate::device_profile::DeviceProfile;
36
/// Default cost multiplier for hot-path nodes.
///
/// Hot-path nodes are nodes the dispatcher recently saw fire (per the
/// I1 hot-path-hint substrate). The extractor prefers cheaper
/// representations on the cold path and is willing to pay more
/// extractor work on hot paths.
///
/// Human-readable mirror of [`HOT_PATH_COST_SCALE_BPS`]; a compile-time
/// check keeps the two values in sync.
pub const HOT_PATH_COST_SCALE: f32 = 0.5;
/// Integer basis-point form of [`HOT_PATH_COST_SCALE`] used by the release
/// extraction path.
pub const HOT_PATH_COST_SCALE_BPS: u32 = 5_000;

/// Default cost multiplier for cold-path nodes.
///
/// Human-readable mirror of [`COLD_PATH_COST_SCALE_BPS`]; a compile-time
/// check keeps the two values in sync.
pub const COLD_PATH_COST_SCALE: f32 = 1.5;
/// Integer basis-point form of [`COLD_PATH_COST_SCALE`] used by the release
/// extraction path.
pub const COLD_PATH_COST_SCALE_BPS: u32 = 15_000;

/// Default tensor-core throughput multiplier for FP16-eligible ALU
/// work on a profile that reports both `supports_tensor_cores` and
/// `supports_f16`. `0.25` = roughly 4× cheaper than f32 ALU.
///
/// Human-readable mirror of [`TENSOR_CORE_COST_SCALE_BPS`]; a compile-time
/// check keeps the two values in sync.
pub const TENSOR_CORE_COST_SCALE: f32 = 0.25;
/// Integer basis-point form of [`TENSOR_CORE_COST_SCALE`] used by the release
/// extraction path.
pub const TENSOR_CORE_COST_SCALE_BPS: u32 = 2_500;

/// Upper clamp for any (possibly composed) scale: `40_000` bps = 4.0×.
const MAX_SCALE_BPS: u32 = 40_000;

/// Basis points per 1.0× multiplier, as `f32`. Used only by the
/// compile-time consistency checks below.
const BPS_PER_UNIT_F32: f32 = 10_000.0;

// This module's scaling code reads only the `_BPS` constants; the `f32`
// forms are public, human-readable mirrors. Fail the build if a pair drifts
// apart, or if a default would already be clipped by `MAX_SCALE_BPS`.
// All three defaults are exactly representable in `f32`, so `==` is exact.
const _: () = {
    assert!(HOT_PATH_COST_SCALE_BPS as f32 == HOT_PATH_COST_SCALE * BPS_PER_UNIT_F32);
    assert!(COLD_PATH_COST_SCALE_BPS as f32 == COLD_PATH_COST_SCALE * BPS_PER_UNIT_F32);
    assert!(TENSOR_CORE_COST_SCALE_BPS as f32 == TENSOR_CORE_COST_SCALE * BPS_PER_UNIT_F32);
    assert!(HOT_PATH_COST_SCALE_BPS <= MAX_SCALE_BPS);
    assert!(COLD_PATH_COST_SCALE_BPS <= MAX_SCALE_BPS);
    assert!(TENSOR_CORE_COST_SCALE_BPS <= MAX_SCALE_BPS);
};
63
/// Per-node hint bits consumed by the device-aware cost helper.
///
/// The helper never derives these itself; callers fill them in from
/// analyses they already run — e.g. `PrecisionHints::lookup(digest)`
/// for `fp16_eligible` and the F1 `vsa_specialization_key` for
/// `compile_time_constant`. `Default` leaves both flags `false`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct NodeHints {
    /// Set when the foundation precision-hint analysis found this node
    /// representable in F16. It only affects the node's cost when the
    /// device profile reports both tensor-core and F16 support.
    pub fp16_eligible: bool,
    /// Set when F1 specialization proved this node's value is a
    /// compile-time constant. Reserved for the F3 dtype-spec wiring;
    /// the cost helper does not read it yet.
    pub compile_time_constant: bool,
}
81
82/// Build a cost closure for `extract_best` parameterised on the
83/// neutral device profile and a hot-path flag.
84///
85/// `base_cost_fn` gives the ABSTRACT per-op cost (e.g. 1 for a const,
86/// 2 for an Add, 4 for a Div). `hint_lookup` answers per-node hint
87/// bits  -  typically a wrapper over `PrecisionHints::lookup`.
88///
89/// The returned closure is `Fn(&L) -> u64` and can be passed
90/// straight into `extract_best`. It owns its arguments by value so
91/// the closure outlives the call frame.
92#[must_use]
93pub fn device_aware_cost<L, B, H>(
94    profile: &DeviceProfile,
95    hot: bool,
96    base_cost_fn: B,
97    hint_lookup: H,
98) -> impl Fn(&L) -> u64
99where
100    L: ENodeLang,
101    B: Fn(&L) -> u64,
102    H: Fn(&L) -> NodeHints,
103{
104    let path_scale_bps = if hot {
105        HOT_PATH_COST_SCALE_BPS
106    } else {
107        COLD_PATH_COST_SCALE_BPS
108    };
109    let tensor_scale_bps = if profile.supports_tensor_cores && profile.supports_f16 {
110        TENSOR_CORE_COST_SCALE_BPS
111    } else {
112        crate::numeric::BASIS_POINTS_DENOMINATOR
113    };
114    move |node: &L| {
115        let base = base_cost_fn(node);
116        let hints = hint_lookup(node);
117        let mut scale_bps = path_scale_bps;
118        if hints.fp16_eligible {
119            scale_bps = compose_scale_basis_points(scale_bps, tensor_scale_bps);
120        }
121        scale_cost_basis_points(base, scale_bps)
122    }
123}
124
125/// Apply an integer basis-point multiplier to a `u64` cost with checked,
126/// deterministic rounding.
127///
128/// Scale is clamped to `[1, 40000]` basis points before scaling; zero
129/// falls back to the base cost to preserve the old invalid-scale contract.
130fn scale_cost_basis_points(base: u64, scale_bps: u32) -> u64 {
131    crate::numeric::scale_u64_by_basis_points_round_clamped(
132        base,
133        scale_bps,
134        base,
135        MAX_SCALE_BPS,
136        "extraction cost",
137        "driver",
138    )
139}
140
141fn compose_scale_basis_points(left_bps: u32, right_bps: u32) -> u32 {
142    crate::numeric::compose_basis_points_u32(
143        left_bps,
144        right_bps,
145        "extraction cost scale composition",
146        "driver",
147    )
148}
149
#[cfg(test)]
mod tests {
    use super::*;
    use vyre_foundation::optimizer::eqsat::{EChildren, ENodeLang};

    /// Trivial language for the cost-helper tests: just a `Const(u32)`
    /// and a synthetic `Heavy` with no children. The base cost
    /// function distinguishes them so we can observe scaling.
    #[derive(Debug, Clone, PartialEq, Eq, Hash)]
    enum Toy {
        Const(u32),
        Heavy,
    }

    impl ENodeLang for Toy {
        fn children(&self) -> EChildren {
            EChildren::new()
        }
        fn with_children(&self, _children: &[vyre_foundation::optimizer::eqsat::EClassId]) -> Self {
            self.clone()
        }
    }

    fn base_cost(node: &Toy) -> u64 {
        match node {
            Toy::Const(_) => 1,
            Toy::Heavy => 100,
        }
    }

    fn no_hints(_: &Toy) -> NodeHints {
        NodeHints::default()
    }

    /// Marks every node FP16-eligible, so only the profile's capability
    /// gate decides whether the tensor-core scale applies.
    fn all_fp16_eligible(_: &Toy) -> NodeHints {
        NodeHints {
            fp16_eligible: true,
            compile_time_constant: false,
        }
    }

    /// Conservative profile with explicit tensor-core / F16 capability bits.
    fn profile_with(tensor_cores: bool, f16: bool) -> DeviceProfile {
        let mut profile = DeviceProfile::conservative("test");
        profile.supports_tensor_cores = tensor_cores;
        profile.supports_f16 = f16;
        profile
    }

    #[test]
    fn cold_path_inflates_base_cost() {
        let profile = DeviceProfile::conservative("test");
        let cost = device_aware_cost(&profile, /*hot=*/ false, base_cost, no_hints);
        assert_eq!(
            cost(&Toy::Heavy),
            scale_cost_basis_points(100, COLD_PATH_COST_SCALE_BPS)
        );
        assert_eq!(
            cost(&Toy::Const(0)),
            scale_cost_basis_points(1, COLD_PATH_COST_SCALE_BPS)
        );
    }

    #[test]
    fn hot_path_shrinks_base_cost() {
        let profile = DeviceProfile::conservative("test");
        let cost = device_aware_cost(&profile, /*hot=*/ true, base_cost, no_hints);
        assert_eq!(
            cost(&Toy::Heavy),
            scale_cost_basis_points(100, HOT_PATH_COST_SCALE_BPS)
        );
        assert_eq!(
            cost(&Toy::Const(0)),
            scale_cost_basis_points(1, HOT_PATH_COST_SCALE_BPS)
        );
    }

    #[test]
    fn hot_path_is_cheaper_than_cold_path() {
        // Direction check that does not mirror the helper arithmetic.
        let profile = DeviceProfile::conservative("test");
        let hot = device_aware_cost(&profile, /*hot=*/ true, base_cost, no_hints);
        let cold = device_aware_cost(&profile, /*hot=*/ false, base_cost, no_hints);
        assert!(hot(&Toy::Heavy) < cold(&Toy::Heavy));
    }

    #[test]
    fn tensor_core_profile_scales_fp16_eligible_nodes() {
        let mut profile = DeviceProfile::conservative("test");
        profile.supports_tensor_cores = true;
        profile.supports_f16 = true;
        let mark_eligible = |node: &Toy| match node {
            Toy::Heavy => NodeHints {
                fp16_eligible: true,
                compile_time_constant: false,
            },
            _ => NodeHints::default(),
        };
        let cost = device_aware_cost(&profile, /*hot=*/ true, base_cost, mark_eligible);
        let expected = scale_cost_basis_points(
            100,
            compose_scale_basis_points(HOT_PATH_COST_SCALE_BPS, TENSOR_CORE_COST_SCALE_BPS),
        );
        assert_eq!(cost(&Toy::Heavy), expected);
        // Const is not fp16-eligible; only hot-path scaling applies.
        assert_eq!(
            cost(&Toy::Const(0)),
            scale_cost_basis_points(1, HOT_PATH_COST_SCALE_BPS)
        );
    }

    #[test]
    fn tensor_core_scaling_makes_fp16_nodes_cheaper() {
        // Profiles are bound to locals so the closures never borrow a
        // temporary (matters under 2024-edition RPIT capture rules).
        let plain_profile = profile_with(false, false);
        let tensor_profile = profile_with(true, true);
        let plain = device_aware_cost(&plain_profile, true, base_cost, all_fp16_eligible);
        let tensor = device_aware_cost(&tensor_profile, true, base_cost, all_fp16_eligible);
        assert!(tensor(&Toy::Heavy) < plain(&Toy::Heavy));
    }

    #[test]
    fn cold_path_composes_with_tensor_core_scale() {
        let profile = profile_with(true, true);
        let cost = device_aware_cost(&profile, /*hot=*/ false, base_cost, all_fp16_eligible);
        let expected = scale_cost_basis_points(
            100,
            compose_scale_basis_points(COLD_PATH_COST_SCALE_BPS, TENSOR_CORE_COST_SCALE_BPS),
        );
        assert_eq!(cost(&Toy::Heavy), expected);
    }

    #[test]
    fn no_tensor_core_support_ignores_fp16_hint() {
        let profile = DeviceProfile::conservative("test");
        assert!(!profile.supports_tensor_cores);
        let mark_eligible = |_: &Toy| NodeHints {
            fp16_eligible: true,
            compile_time_constant: false,
        };
        let cost = device_aware_cost(&profile, /*hot=*/ true, base_cost, mark_eligible);
        // FP16 hint is ignored on a profile that doesn't support tensor cores.
        assert_eq!(
            cost(&Toy::Heavy),
            scale_cost_basis_points(100, HOT_PATH_COST_SCALE_BPS)
        );
    }

    #[test]
    fn tensor_cores_without_f16_ignore_fp16_hint() {
        // Both capability bits are required; tensor cores alone must not
        // unlock the FP16 discount.
        let profile = profile_with(true, false);
        let cost = device_aware_cost(&profile, /*hot=*/ true, base_cost, all_fp16_eligible);
        assert_eq!(
            cost(&Toy::Heavy),
            scale_cost_basis_points(100, HOT_PATH_COST_SCALE_BPS)
        );
    }

    #[test]
    fn f16_without_tensor_cores_ignores_fp16_hint() {
        let profile = profile_with(false, true);
        let cost = device_aware_cost(&profile, /*hot=*/ true, base_cost, all_fp16_eligible);
        assert_eq!(
            cost(&Toy::Heavy),
            scale_cost_basis_points(100, HOT_PATH_COST_SCALE_BPS)
        );
    }

    #[test]
    fn scale_cost_clamps_high_multiplier_basis_points() {
        assert_eq!(scale_cost_basis_points(10, 1_000_000), 40); // 10 * 4.0 cap
    }

    #[test]
    fn zero_basis_point_scale_preserves_invalid_scale_contract() {
        assert_eq!(scale_cost_basis_points(7, 0), 7);
    }

    #[test]
    fn extraction_cost_release_path_uses_integer_scaling() {
        let source = include_str!("extraction_cost.rs");
        let production = source
            .split("#[cfg(test)]")
            .next()
            .expect("Fix: extraction-cost production source must precede tests");

        assert!(
            production.contains("scale_cost_basis_points")
                && production.contains("compose_scale_basis_points")
                && production.contains("crate::numeric::"),
            "Fix: extraction cost scaling must use deterministic integer basis-point arithmetic."
        );
        assert!(
            !production.contains("base as f32")
                && !production.contains("scaled.round()")
                && !production.contains("scale *= tensor_scale"),
            "Fix: extraction cost release path must not use lossy float scaling."
        );
    }

    #[test]
    fn deterministic_for_identical_profiles() {
        let p1 = DeviceProfile::conservative("a");
        let p2 = DeviceProfile::conservative("b");
        let c1 = device_aware_cost(&p1, false, base_cost, no_hints);
        let c2 = device_aware_cost(&p2, false, base_cost, no_hints);
        // Backend name differs but capability bits are identical → same
        // cost output.
        assert_eq!(c1(&Toy::Heavy), c2(&Toy::Heavy));
        assert_eq!(c1(&Toy::Const(7)), c2(&Toy::Const(7)));
    }
}