1use vyre_foundation::optimizer::eqsat::ENodeLang;
34
35use crate::device_profile::DeviceProfile;
36
/// Floating-point form of [`HOT_PATH_COST_SCALE_BPS`] (0.5×).
///
/// [`device_aware_cost`] uses the basis-point form; this value is provided for
/// callers that want the multiplier as a float.
pub const HOT_PATH_COST_SCALE: f32 = 0.5;
/// Cost multiplier for nodes on a hot path, in basis points (5_000 = 0.5×).
pub const HOT_PATH_COST_SCALE_BPS: u32 = 5_000;

/// Floating-point form of [`COLD_PATH_COST_SCALE_BPS`] (1.5×).
pub const COLD_PATH_COST_SCALE: f32 = 1.5;
/// Cost multiplier for nodes on a cold path, in basis points (15_000 = 1.5×).
pub const COLD_PATH_COST_SCALE_BPS: u32 = 15_000;

/// Floating-point form of [`TENSOR_CORE_COST_SCALE_BPS`] (0.25×).
pub const TENSOR_CORE_COST_SCALE: f32 = 0.25;
/// Extra multiplier for fp16-eligible nodes on devices with tensor cores and
/// f16 support, in basis points (2_500 = 0.25×). It is composed with the
/// hot/cold path scale.
pub const TENSOR_CORE_COST_SCALE_BPS: u32 = 2_500;

/// Upper clamp handed to the integer scaler (40_000 = 4×).
const MAX_SCALE_BPS: u32 = 40_000;

// The float and basis-point forms above are maintained by hand. Fail the build
// if they ever drift apart, or if a path scale exceeds the clamp. Otherwise
// the clamp would silently truncate that scale.
const _: () = {
    assert!((HOT_PATH_COST_SCALE * 10_000.0) as u32 == HOT_PATH_COST_SCALE_BPS);
    assert!((COLD_PATH_COST_SCALE * 10_000.0) as u32 == COLD_PATH_COST_SCALE_BPS);
    assert!((TENSOR_CORE_COST_SCALE * 10_000.0) as u32 == TENSOR_CORE_COST_SCALE_BPS);
    assert!(HOT_PATH_COST_SCALE_BPS <= MAX_SCALE_BPS);
    assert!(COLD_PATH_COST_SCALE_BPS <= MAX_SCALE_BPS);
    assert!(TENSOR_CORE_COST_SCALE_BPS <= MAX_SCALE_BPS);
};
63
/// Per-node hints consulted by [`device_aware_cost`].
///
/// `NodeHints::default()` sets every flag to `false`, meaning the node gets
/// only the hot/cold path scale.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct NodeHints {
    /// The node may be evaluated in half precision.
    ///
    /// When the device profile reports both tensor-core and f16 support,
    /// [`device_aware_cost`] multiplies such a node's cost by
    /// [`TENSOR_CORE_COST_SCALE_BPS`] on top of the path scale. Otherwise
    /// the flag has no effect.
    pub fp16_eligible: bool,
    /// Marks the node as a compile-time constant.
    ///
    /// NOTE(review): [`device_aware_cost`] does not read this flag.
    /// Presumably it is consumed elsewhere; confirm.
    pub compile_time_constant: bool,
}
81
82#[must_use]
93pub fn device_aware_cost<L, B, H>(
94 profile: &DeviceProfile,
95 hot: bool,
96 base_cost_fn: B,
97 hint_lookup: H,
98) -> impl Fn(&L) -> u64
99where
100 L: ENodeLang,
101 B: Fn(&L) -> u64,
102 H: Fn(&L) -> NodeHints,
103{
104 let path_scale_bps = if hot {
105 HOT_PATH_COST_SCALE_BPS
106 } else {
107 COLD_PATH_COST_SCALE_BPS
108 };
109 let tensor_scale_bps = if profile.supports_tensor_cores && profile.supports_f16 {
110 TENSOR_CORE_COST_SCALE_BPS
111 } else {
112 crate::numeric::BASIS_POINTS_DENOMINATOR
113 };
114 move |node: &L| {
115 let base = base_cost_fn(node);
116 let hints = hint_lookup(node);
117 let mut scale_bps = path_scale_bps;
118 if hints.fp16_eligible {
119 scale_bps = compose_scale_basis_points(scale_bps, tensor_scale_bps);
120 }
121 scale_cost_basis_points(base, scale_bps)
122 }
123}
124
125fn scale_cost_basis_points(base: u64, scale_bps: u32) -> u64 {
131 crate::numeric::scale_u64_by_basis_points_round_clamped(
132 base,
133 scale_bps,
134 base,
135 MAX_SCALE_BPS,
136 "extraction cost",
137 "driver",
138 )
139}
140
141fn compose_scale_basis_points(left_bps: u32, right_bps: u32) -> u32 {
142 crate::numeric::compose_basis_points_u32(
143 left_bps,
144 right_bps,
145 "extraction cost scale composition",
146 "driver",
147 )
148}
149
#[cfg(test)]
mod tests {
    use super::*;
    use vyre_foundation::optimizer::eqsat::{EChildren, ENodeLang};

    /// Leaf-only toy language. No variant has children, so the cost function
    /// sees each node in isolation.
    #[derive(Debug, Clone, PartialEq, Eq, Hash)]
    enum Toy {
        Const(u32),
        Heavy,
    }

    impl ENodeLang for Toy {
        fn children(&self) -> EChildren {
            EChildren::new()
        }
        fn with_children(&self, _children: &[vyre_foundation::optimizer::eqsat::EClassId]) -> Self {
            self.clone()
        }
    }

    fn base_cost(node: &Toy) -> u64 {
        match node {
            Toy::Const(_) => 1,
            Toy::Heavy => 100,
        }
    }

    fn no_hints(_: &Toy) -> NodeHints {
        NodeHints::default()
    }

    #[test]
    fn cold_path_inflates_base_cost() {
        let profile = DeviceProfile::conservative("test");
        let cost = device_aware_cost(&profile, false, base_cost, no_hints);
        assert_eq!(
            cost(&Toy::Heavy),
            scale_cost_basis_points(100, COLD_PATH_COST_SCALE_BPS)
        );
        // 100 × 1.5 is exact in basis points. Pin the absolute value so the
        // test does not merely compare the scaler against itself.
        assert_eq!(cost(&Toy::Heavy), 150);
        assert_eq!(
            cost(&Toy::Const(0)),
            scale_cost_basis_points(1, COLD_PATH_COST_SCALE_BPS)
        );
    }

    #[test]
    fn hot_path_shrinks_base_cost() {
        let profile = DeviceProfile::conservative("test");
        let cost = device_aware_cost(&profile, true, base_cost, no_hints);
        assert_eq!(
            cost(&Toy::Heavy),
            scale_cost_basis_points(100, HOT_PATH_COST_SCALE_BPS)
        );
        // 100 × 0.5 is exact in basis points.
        assert_eq!(cost(&Toy::Heavy), 50);
        assert_eq!(
            cost(&Toy::Const(0)),
            scale_cost_basis_points(1, HOT_PATH_COST_SCALE_BPS)
        );
    }

    #[test]
    fn tensor_core_profile_scales_fp16_eligible_nodes() {
        let mut profile = DeviceProfile::conservative("test");
        profile.supports_tensor_cores = true;
        profile.supports_f16 = true;
        let mark_eligible = |node: &Toy| match node {
            Toy::Heavy => NodeHints {
                fp16_eligible: true,
                compile_time_constant: false,
            },
            _ => NodeHints::default(),
        };
        let cost = device_aware_cost(&profile, true, base_cost, mark_eligible);
        let expected = scale_cost_basis_points(
            100,
            compose_scale_basis_points(HOT_PATH_COST_SCALE_BPS, TENSOR_CORE_COST_SCALE_BPS),
        );
        assert_eq!(cost(&Toy::Heavy), expected);
        // The tensor-core discount must make the node strictly cheaper than
        // the plain hot-path cost.
        assert!(cost(&Toy::Heavy) < scale_cost_basis_points(100, HOT_PATH_COST_SCALE_BPS));
        // Nodes without the fp16 hint only receive the path scale.
        assert_eq!(
            cost(&Toy::Const(0)),
            scale_cost_basis_points(1, HOT_PATH_COST_SCALE_BPS)
        );
    }

    #[test]
    fn no_tensor_core_support_ignores_fp16_hint() {
        let profile = DeviceProfile::conservative("test");
        assert!(!profile.supports_tensor_cores);
        let mark_eligible = |_: &Toy| NodeHints {
            fp16_eligible: true,
            compile_time_constant: false,
        };
        let cost = device_aware_cost(&profile, true, base_cost, mark_eligible);
        assert_eq!(
            cost(&Toy::Heavy),
            scale_cost_basis_points(100, HOT_PATH_COST_SCALE_BPS)
        );
    }

    #[test]
    fn tensor_cores_without_f16_ignore_fp16_hint() {
        // The discount requires both capabilities; tensor cores alone are
        // not enough.
        let mut profile = DeviceProfile::conservative("test");
        profile.supports_tensor_cores = true;
        profile.supports_f16 = false;
        let mark_eligible = |_: &Toy| NodeHints {
            fp16_eligible: true,
            compile_time_constant: false,
        };
        let cost = device_aware_cost(&profile, true, base_cost, mark_eligible);
        assert_eq!(
            cost(&Toy::Heavy),
            scale_cost_basis_points(100, HOT_PATH_COST_SCALE_BPS)
        );
    }

    #[test]
    fn scale_cost_clamps_high_multiplier_basis_points() {
        assert_eq!(scale_cost_basis_points(10, 1_000_000), 40);
    }

    #[test]
    fn zero_basis_point_scale_preserves_invalid_scale_contract() {
        assert_eq!(scale_cost_basis_points(7, 0), 7);
    }

    #[test]
    fn extraction_cost_release_path_uses_integer_scaling() {
        let source = include_str!("extraction_cost.rs");
        let production = source
            .split("#[cfg(test)]")
            .next()
            .expect("Fix: extraction-cost production source must precede tests");

        assert!(
            production.contains("scale_cost_basis_points")
                && production.contains("compose_scale_basis_points")
                && production.contains("crate::numeric::"),
            "Fix: extraction cost scaling must use deterministic integer basis-point arithmetic."
        );
        assert!(
            !production.contains("base as f32")
                && !production.contains("scaled.round()")
                && !production.contains("scale *= tensor_scale"),
            "Fix: extraction cost release path must not use lossy float scaling."
        );
    }

    #[test]
    fn deterministic_for_identical_profiles() {
        let p1 = DeviceProfile::conservative("a");
        let p2 = DeviceProfile::conservative("b");
        let c1 = device_aware_cost(&p1, false, base_cost, no_hints);
        let c2 = device_aware_cost(&p2, false, base_cost, no_hints);
        assert_eq!(c1(&Toy::Heavy), c2(&Toy::Heavy));
        assert_eq!(c1(&Toy::Const(7)), c2(&Toy::Const(7)));
    }
}