Skip to main content

vyre_driver/strategy/
mod.rs

1//! Backend-specific lowering strategies.
2//!
3//! # Two-Layer Optimization Architecture
4//!
5//! Vyre separates optimizations into two layers with clear separation of
6//! concerns:
7//!
//! ## Layer 1 — IR-Level Passes (`vyre-foundation/src/optimizer/passes/`)
9//!
//! Pure mathematical rewrites that transform `Expr → Expr` in the IR.
//! They are backend-agnostic — every backend benefits equally.
12//!
13//! | Pass | Example | Lives In |
14//! |------|---------|----------|
15//! | Strength reduce | `x / 7` → `mulhi(x, M) >> s` | `strength_reduce/` |
16//! | Const fold | `3 + 4` → `7` | `const_fold/` |
17//! | Shift-add decomp | `x * 5` → `(x<<2) + x` | `strength_reduce/` |
18//! | FMA synthesis | `a*b + c` → `fma(a,b,c)` | `strength_reduce/` |
19//! | Exact division | `(x*6)/3` → `x * inv(3)` | `strength_reduce/` |
20//! | Lemire remainder | `x % 7` → `lowbits(x*M)*7>>32` | `strength_reduce/` |
21//!
//! ## Layer 2 — Backend Lowering Strategies (this module)
23//!
//! Target-dependent emission decisions. These don't change WHAT the program
//! computes — they change HOW it's emitted for a specific chip/API.
26//!
27//! | Strategy | Backend | Effect |
28//! |----------|---------|--------|
29//! | primary-binary native multiply-high | backend | `MulHigh` → 1 instruction |
30//! | secondary-text native multiply-high | backend | `MulHigh` → 1 instruction |
31//! | 16-bit half-word decomp | target-text fallback | `MulHigh` → 14 ALU ops |
32//! | Dual-issue FP32/INT32 | capable device | Division via FP pipeline |
33//! | Matrix-core batching | capable device | Batched int8 multiply |
34//!
35//! # Adding a New Strategy
36//!
37//! 1. Implement [`crate::strategy::LoweringStrategy`] in your backend crate
38//! 2. Register it via `inventory::submit!`
39//! 3. The lowering pipeline auto-selects the highest-priority applicable
40//!    strategy based on [`vyre_foundation::validate::BackendCapabilities`]
41//!
42//! # Vyre Law Zero
43//!
44//! > Runtime performance is sacred. No avoidable runtime overhead, ever.
45//!
//! Layer 1 runs at compile time — zero cost.
//! Layer 2 runs at kernel compile time (once for the megakernel) — amortized to zero.
//! At GPU runtime, only the optimal native instructions execute.
49
50use vyre_foundation::ir::{BinOp, Expr};
51use vyre_foundation::optimizer::passes::algebraic::precision_hint::{
52    PrecisionHint, TranscendentalOp,
53};
54use vyre_foundation::validate::BackendCapabilities;
55
56/// A lowered expression ready for backend emission.
57///
58/// This is the output of a [`LoweringStrategy`]. It can be either a
59/// rewritten Vyre `Expr` or a backend-specific opaque instruction
60/// sequence (represented as a tagged enum for extensibility).
61#[derive(Debug, Clone)]
62pub enum LoweredExpr {
63    /// Rewritten as a Vyre IR expression (most strategies do this).
64    Expr(Expr),
65    /// The strategy handled emission directly  -  the lowering pipeline
66    /// should not process this expression further.
67    Emitted,
68}
69
70/// A backend-specific lowering strategy.
71///
72/// Strategies are the extensibility point for target-dependent
73/// optimizations. Each strategy declares:
74/// - **what** it can optimize (via [`can_apply`](LoweringStrategy::can_apply))
75/// - **how well** (via [`priority`](LoweringStrategy::priority))
76/// - **the transformation** (via [`lower`](LoweringStrategy::lower))
77///
78/// The lowering pipeline selects the highest-priority applicable
79/// strategy for each expression.
80pub trait LoweringStrategy: Send + Sync + std::fmt::Debug {
81    /// Human-readable name for diagnostics and telemetry.
82    fn name(&self) -> &str;
83
84    /// Check whether this strategy applies given the target capabilities
85    /// and the expression being lowered.
86    fn can_apply(&self, caps: &BackendCapabilities, op: &BinOp) -> bool;
87
88    /// Priority for strategy selection. Higher = preferred.
89    ///
90    /// Guidelines:
91    /// - 100: native hardware instruction (OpUMulExtended, mul.hi.u32)
92    /// - 50: multi-instruction but optimal (dual-issue trick)
93    /// - 10: portable decomposition (16-bit arithmetic expansion)
94    fn priority(&self) -> u32;
95
96    /// Lower the given expression using this strategy.
97    ///
98    /// `left` and `right` are the operands of the binary operation.
99    /// The strategy may return a rewritten `Expr` or signal that it
100    /// handled emission directly.
101    fn lower(&self, op: &BinOp, left: &Expr, right: &Expr) -> LoweredExpr;
102}
103
104/// Select the best available strategy for the given operation.
105///
106/// Returns `None` if no registered strategy applies, in which case
107/// the lowering pipeline should use its default emission path.
108pub fn select_strategy<'a>(
109    strategies: &'a [Box<dyn LoweringStrategy>],
110    caps: &BackendCapabilities,
111    op: &BinOp,
112) -> Option<&'a dyn LoweringStrategy> {
113    strategies
114        .iter()
115        .filter(|s| s.can_apply(caps, op))
116        .max_by_key(|s| s.priority())
117        .map(|s| s.as_ref())
118}
119
120/// Concrete lower/emit plan selected from a foundation precision hint.
121#[derive(Debug, Clone, Copy, PartialEq)]
122pub enum PrecisionLoweringPlan {
123    /// Keep the default f32/device-transcendental lowering.
124    DefaultF32,
125    /// Emit this site through native f16 ALU and widen the result to f32.
126    NativeF16 {
127        /// Maximum absolute source operand carried from the foundation hint.
128        max_abs_operand: f32,
129    },
130    /// Emit a bounded polynomial for the transcendental instead of a native
131    /// device call.
132    PolynomialTranscendental {
133        /// Target operation.
134        op: TranscendentalOp,
135        /// Maximum absolute argument bound from the foundation hint.
136        argument_bound: f32,
137        /// Required backend-side polynomial degree.
138        degree: u8,
139    },
140}
141
142/// Select a backend-neutral lower/emit plan for a precision hint.
143///
144/// Foundation owns candidate discovery. This function owns the shared
145/// capability gate every emitter uses before choosing the faster code shape.
146#[must_use]
147pub fn select_precision_lowering(
148    caps: &BackendCapabilities,
149    hint: &PrecisionHint,
150) -> PrecisionLoweringPlan {
151    match hint {
152        PrecisionHint::F16Eligible { max_abs_operand } if caps.has_native_f16 => {
153            PrecisionLoweringPlan::NativeF16 {
154                max_abs_operand: *max_abs_operand,
155            }
156        }
157        PrecisionHint::TranscendentalPolynomial { op, argument_bound }
158            if caps.has_transcendental_polynomial_emit =>
159        {
160            PrecisionLoweringPlan::PolynomialTranscendental {
161                op: *op,
162                argument_bound: *argument_bound,
163                degree: polynomial_degree_for(*op, *argument_bound),
164            }
165        }
166        _ => PrecisionLoweringPlan::DefaultF32,
167    }
168}
169
170fn polynomial_degree_for(op: TranscendentalOp, argument_bound: f32) -> u8 {
171    match op {
172        TranscendentalOp::Sin => {
173            if argument_bound <= 0.25 {
174                3
175            } else {
176                5
177            }
178        }
179        TranscendentalOp::Cos => {
180            if argument_bound <= 0.25 {
181                4
182            } else {
183                6
184            }
185        }
186        TranscendentalOp::Exp | TranscendentalOp::Ln => 5,
187    }
188}
189
#[cfg(test)]
mod tests {
    use super::*;

    /// Test double for a native multiply-high strategy: applies to
    /// `BinOp::MulHigh` only when the target reports `has_mul_high`.
    #[derive(Debug)]
    struct MockNativeStrategy;

    impl LoweringStrategy for MockNativeStrategy {
        fn name(&self) -> &str {
            "mock-native"
        }

        fn can_apply(&self, caps: &BackendCapabilities, op: &BinOp) -> bool {
            match op {
                BinOp::MulHigh => caps.has_mul_high,
                _ => false,
            }
        }

        fn priority(&self) -> u32 {
            100
        }

        fn lower(&self, _op: &BinOp, left: &Expr, right: &Expr) -> LoweredExpr {
            // A real implementation would emit OpUMulExtended here.
            let rewritten = Expr::mulhi(left.clone(), right.clone());
            LoweredExpr::Expr(rewritten)
        }
    }

    /// Test double for a portable fallback: applies to `BinOp::MulHigh`
    /// regardless of target capabilities, at low priority.
    #[derive(Debug)]
    struct MockFallbackStrategy;

    impl LoweringStrategy for MockFallbackStrategy {
        fn name(&self) -> &str {
            "mock-fallback"
        }

        fn can_apply(&self, _caps: &BackendCapabilities, op: &BinOp) -> bool {
            matches!(op, BinOp::MulHigh)
        }

        fn priority(&self) -> u32 {
            10
        }

        fn lower(&self, _op: &BinOp, left: &Expr, right: &Expr) -> LoweredExpr {
            // A real implementation would emit the 16-bit decomposition here.
            let rewritten = Expr::mul(left.clone(), right.clone());
            LoweredExpr::Expr(rewritten)
        }
    }

    /// Fallback first, native second, as a registry would hold them.
    fn mock_strategies() -> Vec<Box<dyn LoweringStrategy>> {
        vec![Box::new(MockFallbackStrategy), Box::new(MockNativeStrategy)]
    }

    /// Default capabilities with only `has_mul_high` overridden.
    fn caps_with_mul_high(has_mul_high: bool) -> BackendCapabilities {
        BackendCapabilities {
            has_mul_high,
            ..Default::default()
        }
    }

    /// Default capabilities with polynomial transcendental emission enabled.
    fn polynomial_caps() -> BackendCapabilities {
        BackendCapabilities {
            has_transcendental_polynomial_emit: true,
            ..Default::default()
        }
    }

    /// Polynomial hint for `sin` with the given argument bound.
    fn sin_hint(argument_bound: f32) -> PrecisionHint {
        PrecisionHint::TranscendentalPolynomial {
            op: TranscendentalOp::Sin,
            argument_bound,
        }
    }

    /// f16-eligibility hint with a maximum absolute operand of 4.0.
    fn f16_hint() -> PrecisionHint {
        PrecisionHint::F16Eligible {
            max_abs_operand: 4.0,
        }
    }

    #[test]
    fn selects_highest_priority() {
        let strategies = mock_strategies();
        let caps = caps_with_mul_high(true);

        let chosen = select_strategy(&strategies, &caps, &BinOp::MulHigh).map(|s| s.name());

        assert_eq!(chosen, Some("mock-native"));
    }

    #[test]
    fn falls_back_when_native_unavailable() {
        let strategies = mock_strategies();
        let caps = caps_with_mul_high(false);

        let chosen = select_strategy(&strategies, &caps, &BinOp::MulHigh).map(|s| s.name());

        assert_eq!(chosen, Some("mock-fallback"));
    }

    #[test]
    fn returns_none_for_unsupported_op() {
        let strategies: Vec<Box<dyn LoweringStrategy>> = vec![Box::new(MockNativeStrategy)];
        let caps = caps_with_mul_high(true);

        let chosen = select_strategy(&strategies, &caps, &BinOp::Add);

        assert!(chosen.is_none());
    }

    #[test]
    fn precision_hint_selects_native_f16_when_supported() {
        let caps = BackendCapabilities {
            has_native_f16: true,
            ..Default::default()
        };

        let plan = select_precision_lowering(&caps, &f16_hint());

        let expected = PrecisionLoweringPlan::NativeF16 {
            max_abs_operand: 4.0,
        };
        assert_eq!(plan, expected);
    }

    #[test]
    fn precision_hint_keeps_f32_without_native_f16() {
        let caps = BackendCapabilities::default();

        let plan = select_precision_lowering(&caps, &f16_hint());

        assert_eq!(plan, PrecisionLoweringPlan::DefaultF32);
    }

    #[test]
    fn transcendental_hint_selects_polynomial_when_supported() {
        let plan = select_precision_lowering(&polynomial_caps(), &sin_hint(0.2));

        let expected = PrecisionLoweringPlan::PolynomialTranscendental {
            op: TranscendentalOp::Sin,
            argument_bound: 0.2,
            degree: 3,
        };
        assert_eq!(plan, expected);
    }

    #[test]
    fn transcendental_hint_uses_higher_degree_for_wider_sin_range() {
        let plan = select_precision_lowering(&polynomial_caps(), &sin_hint(0.75));

        let expected = PrecisionLoweringPlan::PolynomialTranscendental {
            op: TranscendentalOp::Sin,
            argument_bound: 0.75,
            degree: 5,
        };
        assert_eq!(plan, expected);
    }
}
339                degree: 5,
340            }
341        );
342    }
343}