Skip to main content

svod_tensor/
config.rs

1use std::sync::Arc;
2
3use snafu::ResultExt;
4use svod_device::device::Device;
5use svod_device::registry::DeviceRegistry;
6use svod_ir::DeviceSpec;
7use svod_runtime::CpuBackend;
8use svod_schedule::OptimizerConfig;
9
10use crate::error::{DeviceFactorySnafu, DeviceSnafu};
11
12/// Resolves a `DeviceSpec` into a concrete `Device` for compilation.
13///
14/// Implementations control which codegen backend is used for each device type.
15/// This enables per-call backend selection instead of relying on the
16/// `DEVICE_FACTORIES` singleton (which bakes one backend per device spec).
17pub(crate) trait DeviceResolver: Send + Sync {
18    fn resolve(&self, spec: &DeviceSpec, registry: &DeviceRegistry) -> crate::Result<Arc<Device>>;
19}
20
21/// Default resolver: delegates to `DEVICE_FACTORIES` singleton (reads env vars
22/// like `SVOD_CPU_BACKEND` at first device creation, then caches).
23struct EnvResolver;
24
25impl DeviceResolver for EnvResolver {
26    fn resolve(&self, spec: &DeviceSpec, registry: &DeviceRegistry) -> crate::Result<Arc<Device>> {
27        svod_runtime::DEVICE_FACTORIES.device(spec, registry).context(DeviceFactorySnafu)
28    }
29}
30
31/// Creates CPU devices with a specific backend; delegates other device types
32/// to `DEVICE_FACTORIES`. This is the resolver used by `PrepareConfig::for_cpu_backend()`.
33struct CpuBackendResolver(CpuBackend);
34
35impl DeviceResolver for CpuBackendResolver {
36    fn resolve(&self, spec: &DeviceSpec, registry: &DeviceRegistry) -> crate::Result<Arc<Device>> {
37        match spec {
38            DeviceSpec::Cpu => {
39                Ok(Arc::new(svod_runtime::create_cpu_device_with_backend(registry, self.0).context(DeviceSnafu)?))
40            }
41            _ => svod_runtime::DEVICE_FACTORIES.device(spec, registry).context(DeviceFactorySnafu),
42        }
43    }
44}
45
46/// Configuration for `prepare()`/`realize()` that bundles optimizer settings
47/// with device resolution (codegen backend selection).
48///
49/// Instead of relying on the `SVOD_CPU_BACKEND` env var (global mutable state),
50/// the backend is selected per-call via a [`DeviceResolver`].
51#[allow(rustdoc::private_intra_doc_links)]
52pub struct PrepareConfig {
53    pub optimizer: OptimizerConfig,
54    pub(crate) resolver: Arc<dyn DeviceResolver>,
55    /// When `true`, force the cache-cold rangeify/scheduling path even if
56    /// `SVOD_DISABLE_SCHEDULE_CACHE` is unset. Primarily useful in tests
57    /// that need to compare cache-warm vs cache-cold outputs without mutating
58    /// process-global env state.
59    pub disable_schedule_cache: bool,
60}
61
62impl std::fmt::Debug for PrepareConfig {
63    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64        f.debug_struct("PrepareConfig")
65            .field("optimizer", &self.optimizer)
66            .field("disable_schedule_cache", &self.disable_schedule_cache)
67            .finish_non_exhaustive()
68    }
69}
70
71impl Default for PrepareConfig {
72    fn default() -> Self {
73        Self { optimizer: OptimizerConfig::default(), resolver: Arc::new(EnvResolver), disable_schedule_cache: false }
74    }
75}
76
77impl PrepareConfig {
78    /// Read both `SVOD_CPU_BACKEND` and optimizer env vars.
79    pub fn from_env() -> Self {
80        Self { optimizer: OptimizerConfig::from_env(), resolver: Arc::new(EnvResolver), disable_schedule_cache: false }
81    }
82
83    /// Convenience constructor: specific CPU backend with optimizer settings
84    /// resolved from env (`BEAM`, `SVOD_NOOPT`, `IGNORE_BEAM_CACHE`,
85    /// `BEAM_*`, `SVOD_*`). Used by the `codegen_tests!` macro so a single
86    /// `BEAM=4 cargo test` flips every codegen-test target to BEAM
87    /// without changing test bodies.
88    pub fn for_cpu_backend(backend: CpuBackend) -> Self {
89        Self {
90            optimizer: OptimizerConfig::from_env(),
91            resolver: Arc::new(CpuBackendResolver(backend)),
92            disable_schedule_cache: false,
93        }
94    }
95
96    /// Resolve a `DeviceSpec` into a `Device` using this config's resolver.
97    pub(crate) fn resolve_device(&self, spec: &DeviceSpec, registry: &DeviceRegistry) -> crate::Result<Arc<Device>> {
98        self.resolver.resolve(spec, registry)
99    }
100}
101
102impl From<OptimizerConfig> for PrepareConfig {
103    fn from(optimizer: OptimizerConfig) -> Self {
104        Self { optimizer, resolver: Arc::new(EnvResolver), disable_schedule_cache: false }
105    }
106}
107
/// Expands one test body into a pair of tests, one per CPU codegen backend
/// (`CpuBackend::Clang` and `CpuBackend::Llvm`).
///
/// Every generated test calls `svod_schedule::testing::setup_test_tracing()`.
/// It then binds the identifier given as the first parameter to
/// `PrepareConfig::for_cpu_backend(<backend>)` before running the body. The
/// expansion names `::svod_schedule`, `::proptest` and `::test_case` by
/// absolute path, so the invoking crate needs whichever of those its chosen
/// form uses as a dependency. Several tests may be listed back to back in one
/// invocation.
///
/// Three input shapes are accepted:
///
/// **Plain test** (only the config binding):
/// ```ignore
/// codegen_tests! {
///     fn test_add(config) {
///         let mut a = Tensor::from_slice([1.0f32, 2.0, 3.0]);
///         a.realize_with(&config).unwrap();
///         let result: Vec<f32> = a.as_vec().unwrap();
///     }
/// }
/// // Produces: test_add::clang, test_add::llvm
/// ```
///
/// **Parameterized test** (extra `name: Type` parameters driven by
/// `#[test_case]` attributes; this form does not add `#[test]` itself):
/// ```ignore
/// codegen_tests! {
///     #[test_case(128, 0.5; "128x128")]
///     fn test_matmul(config, size: usize, tol: f32) {
///         let mut result = run_matmul(size);
///         result.realize_with(&config).unwrap();
///         assert_close(&result, tol);
///     }
/// }
/// // Wrapped in test_matmul::clang::test_matmul and
/// // test_matmul::llvm::test_matmul; the per-case names below those are
/// // generated by `test_case`.
/// ```
///
/// **Property test** (`name in strategy` parameters, run through
/// `proptest::test_runner::TestRunner`):
/// ```ignore
/// codegen_tests! {
///     #[proptest_config(ProptestConfig::with_cases(50))]
///     fn test_sort_random(config, data in proptest::collection::vec(-100.0f32..100.0, 1..=16)) {
///         let mut t = Tensor::from_slice(&data);
///         let (sorted, _) = t.sort(-1, false).unwrap();
///         // ...
///     }
/// }
/// // Produces: test_sort_random::clang, test_sort_random::llvm
/// ```
/// The optional `#[proptest_config(...)]` tokens become the argument of
/// `TestRunner::new`; without it `TestRunner::default()` is used. It must be
/// the *first* attribute. Otherwise it is passed through as an ordinary
/// attribute on the generated test functions.
#[macro_export]
macro_rules! codegen_tests {
    // Internal: emit the clang/llvm property tests. Placed first so the
    // `@proptest` marker matches immediately; user-facing inputs never begin
    // with `@`, so they fall through to the arms below.
    (@proptest $test_name:ident, $cfg_var:ident, [$($arg:ident in $strat:expr),+], $test_body:block, $make_runner:expr, [$(#[$attr:meta])*]) => {
        mod $test_name {
            #[allow(unused_imports)]
            use super::*;

            #[test]
            #[allow(unused_parens)]
            $(#[$attr])*
            fn clang() {
                ::svod_schedule::testing::setup_test_tracing();
                let mut proptest_runner = $make_runner;
                // A single strategy becomes a parenthesized expression/pattern,
                // hence the `unused_parens` allowance above.
                proptest_runner.run(&($($strat),+), |($($arg),+)| {
                    let $cfg_var = $crate::PrepareConfig::for_cpu_backend($crate::CpuBackend::Clang);
                    $test_body
                    Ok(())
                }).unwrap();
            }

            #[test]
            #[allow(unused_parens)]
            $(#[$attr])*
            fn llvm() {
                ::svod_schedule::testing::setup_test_tracing();
                let mut proptest_runner = $make_runner;
                proptest_runner.run(&($($strat),+), |($($arg),+)| {
                    let $cfg_var = $crate::PrepareConfig::for_cpu_backend($crate::CpuBackend::Llvm);
                    $test_body
                    Ok(())
                }).unwrap();
            }
        }
    };

    // End of input: nothing left to expand.
    () => {};

    // Plain test: `fn name(config) { ... }`.
    ($(#[$attr:meta])* fn $test_name:ident($cfg_var:ident) $test_body:block $($tail:tt)*) => {
        mod $test_name {
            #[allow(unused_imports)]
            use super::*;

            #[test]
            $(#[$attr])*
            fn clang() {
                ::svod_schedule::testing::setup_test_tracing();
                let $cfg_var = $crate::PrepareConfig::for_cpu_backend($crate::CpuBackend::Clang);
                $test_body
            }

            #[test]
            $(#[$attr])*
            fn llvm() {
                ::svod_schedule::testing::setup_test_tracing();
                let $cfg_var = $crate::PrepareConfig::for_cpu_backend($crate::CpuBackend::Llvm);
                $test_body
            }
        }
        $crate::codegen_tests!($($tail)*);
    };

    // Property test with an explicit runner config. Must precede the
    // default-config arm, which would otherwise accept `#[proptest_config(..)]`
    // as an ordinary attribute.
    (#[proptest_config($($runner_cfg:tt)*)] $(#[$attr:meta])* fn $test_name:ident($cfg_var:ident, $($arg:ident in $strat:expr),+ $(,)?) $test_body:block $($tail:tt)*) => {
        $crate::codegen_tests!(@proptest $test_name, $cfg_var, [$($arg in $strat),+], $test_body,
            ::proptest::test_runner::TestRunner::new($($runner_cfg)*), [$(#[$attr])*]);
        $crate::codegen_tests!($($tail)*);
    };

    // Property test using `TestRunner::default()`.
    ($(#[$attr:meta])* fn $test_name:ident($cfg_var:ident, $($arg:ident in $strat:expr),+ $(,)?) $test_body:block $($tail:tt)*) => {
        $crate::codegen_tests!(@proptest $test_name, $cfg_var, [$($arg in $strat),+], $test_body,
            ::proptest::test_runner::TestRunner::default(), [$(#[$attr])*]);
        $crate::codegen_tests!($($tail)*);
    };

    // Parameterized test: extra `name: Type` params; the caller supplies
    // `#[test_case(..)]` attributes, so no `#[test]` is emitted here.
    ($(#[$attr:meta])* fn $test_name:ident($cfg_var:ident, $($arg:ident: $arg_ty:ty),+ $(,)?) $test_body:block $($tail:tt)*) => {
        mod $test_name {
            mod clang {
                #[allow(unused_imports)]
                use super::super::*;
                use ::test_case::test_case;

                $(#[$attr])*
                fn $test_name($($arg: $arg_ty),+) {
                    ::svod_schedule::testing::setup_test_tracing();
                    let $cfg_var = $crate::PrepareConfig::for_cpu_backend($crate::CpuBackend::Clang);
                    $test_body
                }
            }
            mod llvm {
                #[allow(unused_imports)]
                use super::super::*;
                use ::test_case::test_case;

                $(#[$attr])*
                fn $test_name($($arg: $arg_ty),+) {
                    ::svod_schedule::testing::setup_test_tracing();
                    let $cfg_var = $crate::PrepareConfig::for_cpu_backend($crate::CpuBackend::Llvm);
                    $test_body
                }
            }
        }
        $crate::codegen_tests!($($tail)*);
    };
}