Skip to main content

svod_device/
device.rs

1//! Device abstraction following Tinygrad's architecture.
2//!
3//! This module provides a unified Device abstraction that owns:
4//! - **Renderer**: Transforms UOp graphs into source code (ProgramSpec)
5//! - **Compiler**: Transforms source code into executable bytes
6//! - **Runtime**: Creates executable Programs from compiled bytes
7//! - **Allocator**: Manages memory allocation for buffers
8//!
9//! This design allows multiple backends (LLVM, CUDA, Metal, WebGPU) to coexist
10//! and share compiled kernels via the method cache.
11
12use std::collections::HashMap;
13use std::sync::Arc;
14
15use svod_dtype::DeviceSpec;
16use svod_ir::{BinaryOp, ConstValue, Op, UOp, UnaryOp};
17
18use crate::allocator::Allocator;
19use crate::error::{Error, Result};
20
21/// A compiled, executable kernel program.
22///
23/// This trait abstracts over different backend executors (LLVM JIT, CUDA, Metal, etc.).
24/// Each backend implements this to provide unified execution interface.
25///
26/// Implementations must be stateless and reentrant from the host perspective.
27/// The runtime caches and shares programs across execution plans, and may invoke
28/// the same program from multiple host threads when dependency analysis proves
29/// the buffer accesses are independent.
30///
31/// # Tinygrad Alignment
32///
33/// This trait follows Tinygrad's `Program` interface where variable values are
34/// passed as a positional tuple/array (`vals`) rather than a named HashMap.
35/// The order matches `var_names` in `CompiledSpec`.
36pub trait Program: Send + Sync {
37    /// Execute the kernel with given buffers and variable values.
38    ///
39    /// # Arguments
40    ///
41    /// * `buffers` - Raw pointers to buffer data (input and output buffers)
42    /// * `vals` - Variable values in positional order (matches `var_names` in CompiledSpec)
43    /// * `global_size` - Global work size (for GPU backends, None for CPU)
44    /// * `local_size` - Local work size (for GPU backends, None for CPU)
45    ///
46    /// # Safety
47    ///
48    /// This is unsafe because:
49    /// - Buffer pointers must be valid and properly aligned
50    /// - Buffer sizes must match what the kernel expects
51    /// - Caller must ensure no data races during execution
52    unsafe fn execute(
53        &self,
54        buffers: &[*mut u8],
55        vals: &[i64],
56        global_size: Option<[usize; 3]>,
57        local_size: Option<[usize; 3]>,
58    ) -> Result<()>;
59
60    /// Get the kernel name (for debugging/profiling).
61    fn name(&self) -> &str;
62}
63
64/// Compilation result carrying source (JIT) or bytes (AOT).
65///
66/// Different backends need different information:
67/// - LLVM JIT: needs source code to compile during runtime
68/// - CUDA: needs PTX/CUBIN bytes to load
69/// - Metal: needs metallib bytes to load
70///
71/// This design allows the RuntimeFactory to access whatever it needs
72/// without requiring separate code paths for JIT vs AOT backends.
73#[derive(Debug, Clone)]
74pub struct CompiledSpec {
75    /// Entry point function name
76    pub name: String,
77
78    /// Source code (for JIT backends like LLVM)
79    /// Set to Some(...) for LLVM JIT, None for AOT backends
80    pub src: Option<String>,
81
82    /// Compiled bytes (for AOT backends like CUDA/Metal)
83    /// Empty for LLVM JIT, populated for AOT backends
84    pub bytes: Vec<u8>,
85
86    /// Original AST for cache key construction via hash consing
87    pub ast: Arc<UOp>,
88
89    /// Variable names in order for populating vars array at runtime.
90    /// Includes runtime variables such as core_id.
91    pub var_names: Vec<String>,
92
93    /// Symbolic global work size for dispatch.
94    pub global_size: [Arc<UOp>; 3],
95
96    /// Symbolic local work size for dispatch. None means direct global-id execution.
97    pub local_size: Option<[Arc<UOp>; 3]>,
98
99    /// Number of buffer arguments (for CIF construction at compile time).
100    pub buf_count: usize,
101}
102
103impl CompiledSpec {
104    /// Create a new CompiledSpec for JIT backends (source-based).
105    pub fn from_source(name: String, src: String, ast: Arc<UOp>, buf_count: usize) -> Self {
106        Self {
107            name,
108            src: Some(src),
109            bytes: Vec::new(),
110            ast,
111            var_names: Vec::new(),
112            global_size: default_launch_size(),
113            local_size: Some(default_launch_size()),
114            buf_count,
115        }
116    }
117
118    /// Create a new CompiledSpec for AOT backends (bytecode-based).
119    pub fn from_bytes(name: String, bytes: Vec<u8>, ast: Arc<UOp>) -> Self {
120        Self {
121            name,
122            src: None,
123            bytes,
124            ast,
125            var_names: Vec::new(),
126            global_size: default_launch_size(),
127            local_size: Some(default_launch_size()),
128            buf_count: 0,
129        }
130    }
131
132    /// Create a new CompiledSpec with work sizes for JIT backends.
133    pub fn from_source_with_sizes(
134        name: String,
135        src: String,
136        ast: Arc<UOp>,
137        global_size: [usize; 3],
138        local_size: Option<[usize; 3]>,
139        buf_count: usize,
140    ) -> Self {
141        Self {
142            name,
143            src: Some(src),
144            bytes: Vec::new(),
145            ast,
146            var_names: Vec::new(),
147            global_size: concrete_launch_size(global_size),
148            local_size: local_size.map(concrete_launch_size),
149            buf_count,
150        }
151    }
152}
153
/// Fully evaluated launch dimensions, ready to hand to a backend runtime.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ConcreteLaunchDims {
    /// Global work size per axis.
    pub global_size: [usize; 3],
    /// Local work size per axis, if the kernel uses one.
    pub local_size: Option<[usize; 3]>,
}
160
161fn default_launch_size() -> [Arc<UOp>; 3] {
162    [UOp::index_const(1), UOp::index_const(1), UOp::index_const(1)]
163}
164
165fn concrete_launch_size(size: [usize; 3]) -> [Arc<UOp>; 3] {
166    [UOp::index_const(size[0] as i64), UOp::index_const(size[1] as i64), UOp::index_const(size[2] as i64)]
167}
168
169fn const_value_to_i64(value: ConstValue) -> Result<i64> {
170    match value {
171        ConstValue::Int(v) => Ok(v),
172        ConstValue::UInt(v) => i64::try_from(v)
173            .map_err(|_| Error::Runtime { message: format!("launch-size constant {v} does not fit i64") }),
174        ConstValue::Bool(v) => Ok(i64::from(v)),
175        ConstValue::Float(v) => {
176            Err(Error::Runtime { message: format!("launch-size expression must be integer, got float constant {v}") })
177        }
178    }
179}
180
181fn validate_var_bound(name: &str, value: i64, min_val: i64, max_val: i64) -> Result<()> {
182    if value < min_val || value > max_val {
183        return Err(Error::Runtime {
184            message: format!("variable {name}={value} is outside bounds [{min_val}, {max_val}]"),
185        });
186    }
187    Ok(())
188}
189
190fn checked_launch_binary(op: BinaryOp, lhs: i64, rhs: i64) -> Result<i64> {
191    let value = match op {
192        BinaryOp::Add => lhs.checked_add(rhs),
193        BinaryOp::Sub => lhs.checked_sub(rhs),
194        BinaryOp::Mul => lhs.checked_mul(rhs),
195        BinaryOp::Idiv => (rhs != 0).then(|| lhs.checked_div(rhs)).flatten(),
196        BinaryOp::Mod => (rhs != 0).then(|| lhs.checked_rem(rhs)).flatten(),
197        BinaryOp::Max => Some(lhs.max(rhs)),
198        _ => {
199            return Err(Error::Runtime { message: format!("unsupported binary op in launch-size expression: {op:?}") });
200        }
201    };
202
203    value.ok_or_else(|| Error::Runtime { message: format!("invalid launch-size arithmetic: {lhs} {op:?} {rhs}") })
204}
205
206fn eval_launch_expr(expr: &Arc<UOp>, vars: &HashMap<&str, i64>) -> Result<i64> {
207    match expr.op() {
208        Op::Const(value) => const_value_to_i64(value.0),
209        Op::DefineVar { name, min_val, max_val } => {
210            let value = vars.get(name.as_str()).copied().ok_or_else(|| Error::Runtime {
211                message: format!("missing runtime value for launch-size variable {name}"),
212            })?;
213            validate_var_bound(name, value, *min_val, *max_val)?;
214            Ok(value)
215        }
216        Op::Bind { var, value } => {
217            let bound = eval_launch_expr(value, vars)?;
218            if let Op::DefineVar { name, min_val, max_val } = var.op() {
219                validate_var_bound(name, bound, *min_val, *max_val)?;
220            }
221            Ok(bound)
222        }
223        Op::Binary(op, lhs, rhs) => {
224            checked_launch_binary(*op, eval_launch_expr(lhs, vars)?, eval_launch_expr(rhs, vars)?)
225        }
226        Op::Unary(UnaryOp::Neg, src) => eval_launch_expr(src, vars)?
227            .checked_neg()
228            .ok_or_else(|| Error::Runtime { message: "invalid launch-size negation overflow".to_string() }),
229        Op::Unary(UnaryOp::Abs, src) => eval_launch_expr(src, vars)?
230            .checked_abs()
231            .ok_or_else(|| Error::Runtime { message: "invalid launch-size abs overflow".to_string() }),
232        Op::Cast { src, .. } | Op::BitCast { src, .. } | Op::After { passthrough: src, .. } => {
233            eval_launch_expr(src, vars)
234        }
235        other => Err(Error::Runtime { message: format!("unsupported launch-size expression op: {other:?}") }),
236    }
237}
238
239fn eval_launch_size(size: &[Arc<UOp>; 3], vars: &HashMap<&str, i64>) -> Result<[usize; 3]> {
240    let mut out = [1usize; 3];
241    for (idx, expr) in size.iter().enumerate() {
242        let value = eval_launch_expr(expr, vars)?;
243        if value <= 0 {
244            return Err(Error::Runtime {
245                message: format!("launch dimension {idx} evaluated to non-positive value {value}"),
246            });
247        }
248        out[idx] = usize::try_from(value).map_err(|_| Error::Runtime {
249            message: format!("launch dimension {idx} value {value} does not fit usize"),
250        })?;
251    }
252    Ok(out)
253}
254
255/// A compiler that transforms source code into a compiled specification.
256///
257/// This trait abstracts over different compilation backends:
258/// - LLVM: IR validation (JIT compiles at runtime)
259/// - CUDA: CUDA C -> PTX/CUBIN
260/// - Metal: Metal Shading Language -> metallib
261/// - WebGPU: WGSL -> SPIR-V
262pub trait Compiler: Send + Sync {
263    /// Compile a program specification into executable form.
264    ///
265    /// # Arguments
266    ///
267    /// * `spec` - The program specification containing source code and metadata
268    ///
269    /// # Returns
270    ///
271    /// A CompiledSpec containing:
272    /// - For JIT backends (LLVM): source code in `src` field, empty `bytes`
273    /// - For AOT backends (CUDA/Metal): compiled bytes in `bytes` field, no `src`
274    ///
275    /// # Examples
276    ///
277    /// JIT backend (LLVM):
278    /// ```ignore
279    /// let compiled = compiler.compile(&spec)?;
280    /// assert!(compiled.src.is_some());
281    /// assert!(compiled.bytes.is_empty());
282    /// ```
283    ///
284    /// AOT backend (CUDA):
285    /// ```ignore
286    /// let compiled = compiler.compile(&spec)?;
287    /// assert!(compiled.src.is_none());
288    /// assert!(!compiled.bytes.is_empty());
289    /// ```
290    fn compile(&self, spec: &ProgramSpec) -> Result<CompiledSpec>;
291
292    /// Cache key identifying this compiler backend.
293    ///
294    /// Used to differentiate compiled artifacts when the same device type
295    /// can have multiple compiler backends (e.g., clang vs llvm-jit).
296    fn cache_key(&self) -> &'static str;
297}
298
299/// A renderer that transforms UOp graphs into source code.
300///
301/// This trait abstracts over different code generation backends:
302/// - LLVM IR generator
303/// - CUDA C generator
304/// - Metal Shading Language generator
305/// - WGSL generator
306pub trait Renderer: Send + Sync {
307    /// Render a UOp graph into source code.
308    ///
309    /// # Arguments
310    ///
311    /// * `ast` - The kernel AST (UOp graph rooted at a CALL body such as SINK/PROGRAM)
312    /// * `name` - Optional kernel name for debugging (e.g., "r_g16l16R32u4").
313    ///   Falls back to "kernel" if None.
314    ///
315    /// # Returns
316    ///
317    /// A ProgramSpec containing:
318    /// - Generated source code
319    /// - Entry point name
320    /// - Variable list
321    /// - Work sizes (for GPU backends)
322    fn render(&self, ast: &Arc<UOp>, name: Option<&str>) -> Result<ProgramSpec>;
323
324    /// Get the device spec for this renderer.
325    ///
326    /// This is used for cache key construction and device selection.
327    fn device(&self) -> &DeviceSpec;
328
329    /// Returns decomposition patterns for operations this backend doesn't support.
330    ///
331    /// This is used by the realization pass to decompose complex operations
332    /// into simpler primitives before rendering.
333    ///
334    /// # Default Implementation
335    ///
336    /// Returns `None`, meaning no decomposition is needed (backend supports all ops).
337    /// Backends that don't support certain operations (e.g., transcendentals)
338    /// should override this to return appropriate patterns.
339    fn decompositor(&self) -> Option<svod_ir::pattern::TypedPatternMatcher<()>> {
340        None
341    }
342}
343
344/// A factory function that creates executable Programs from a compiled specification.
345///
346/// This is a function pointer that wraps the backend-specific loader:
347/// - LLVM: Extract source from CompiledSpec and JIT compile
348/// - CUDA: Extract bytes from CompiledSpec and call cuModuleLoadData + cuModuleGetFunction
349/// - Metal: Extract bytes from CompiledSpec and call newLibraryWithData + newFunctionWithName
350/// - WebGPU: Extract bytes from CompiledSpec and call createShaderModule
351///
352/// The CompiledSpec contains either source (for JIT) or bytes (for AOT),
353/// allowing each backend to access what it needs.
354pub type RuntimeFactory = Arc<dyn Fn(&CompiledSpec) -> Result<Box<dyn Program>> + Send + Sync>;
355
356/// A (Renderer, Compiler) pair for a specific backend.
357///
358/// Devices can have multiple compiler pairs (e.g., different optimization levels).
359pub type CompilerPair = (Arc<dyn Renderer>, Arc<dyn Compiler>);
360
361/// A device that owns renderer, compiler, runtime, and allocator.
362///
363/// This follows Tinygrad's architecture where a Device is a complete
364/// compilation + execution unit for a specific backend.
365///
366/// # Example
367///
368/// ```ignore
369/// let cpu_device = create_cpu_device()?;
370/// let spec = cpu_device.renderer.render(&kernel_ast, Some("E_L3"))?;
371/// let compiled = cpu_device.compiler.compile(&spec)?;
372/// let program = (cpu_device.runtime)(&compiled)?;
373/// unsafe { program.execute(&buffers, &vals, None, None)?; }
374/// ```
375pub struct Device {
376    /// Device specification
377    pub device: DeviceSpec,
378
379    /// Memory allocator for this device
380    pub allocator: Arc<dyn Allocator>,
381
382    /// Available (renderer, compiler) pairs for this device
383    ///
384    /// Most devices have one pair, but some may have multiple
385    /// (e.g., different optimization levels or compilation modes).
386    pub compilers: Vec<CompilerPair>,
387
388    /// Primary renderer for this device
389    ///
390    /// This is typically `compilers[0].0`, stored separately for convenience.
391    pub renderer: Arc<dyn Renderer>,
392
393    /// Primary compiler for this device
394    ///
395    /// This is typically `compilers[0].1`, stored separately for convenience.
396    pub compiler: Arc<dyn Compiler>,
397
398    /// Runtime factory for creating executable programs
399    ///
400    /// Takes (entry_point, compiled_bytes) and returns a Program.
401    pub runtime: RuntimeFactory,
402}
403
404impl Device {
405    /// Create a new device with a single compiler pair.
406    ///
407    /// This is a convenience constructor for the common case where
408    /// a device has only one renderer/compiler combination.
409    pub fn new(
410        device: DeviceSpec,
411        allocator: Arc<dyn Allocator>,
412        renderer: Arc<dyn Renderer>,
413        compiler: Arc<dyn Compiler>,
414        runtime: RuntimeFactory,
415    ) -> Self {
416        let compilers = vec![(renderer.clone(), compiler.clone())];
417        Self { device, allocator, compilers, renderer, compiler, runtime }
418    }
419
420    /// Get the base device key (strips device ID).
421    ///
422    /// Used for compiled byte cache sharing across device instances.
423    /// Examples:
424    /// - DeviceSpec::Cpu -> "CPU"
425    /// - DeviceSpec::Cuda { device_id: 0 } -> "CUDA"
426    /// - DeviceSpec::Cuda { device_id: 1 } -> "CUDA"
427    /// - DeviceSpec::Metal { device_id: 0 } -> "Metal"
428    ///
429    /// This allows compiled CUDA kernels to be reused across CUDA:0 and CUDA:1.
430    pub fn base_device_key(&self) -> &'static str {
431        self.device.base_type()
432    }
433}
434
435/// Program specification containing source code and metadata.
436///
437/// This is returned by Renderer::render() and consumed by Compiler::compile().
438/// It bridges the gap between UOp graphs and compiled executables.
439///
440/// # Tinygrad Alignment
441///
442/// Buffer metadata (`globals`, `outs`, `ins`) matches Tinygrad's Program class:
443/// - `globals`: Buffer indices from PARAM ops
444/// - `outs`: Output buffer indices (written by STORE ops)
445/// - `ins`: Input buffer indices (read by LOAD ops)
446#[derive(Debug, Clone)]
447pub struct ProgramSpec {
448    /// Kernel name (for debugging/profiling)
449    pub name: String,
450
451    /// Generated source code (LLVM IR, CUDA C, Metal, WGSL, etc.)
452    pub src: String,
453
454    /// Device specification
455    pub device: DeviceSpec,
456
457    /// Original AST (for cache key construction via hash consing)
458    pub ast: Arc<UOp>,
459
460    /// Symbolic global work size.
461    pub global_size: [Arc<UOp>; 3],
462
463    /// Symbolic local work size. None means direct global-id execution.
464    pub local_size: Option<[Arc<UOp>; 3]>,
465
466    /// Variable list (for symbolic shapes/strides)
467    pub vars: Vec<Variable>,
468
469    /// Variable names in order for populating vars array at runtime.
470    /// Includes runtime variables such as core_id.
471    pub var_names: Vec<String>,
472
473    /// Global buffer indices (from PARAM slot values).
474    /// Matches Tinygrad's `globals` field.
475    pub globals: Vec<usize>,
476
477    /// Output buffer indices (written by STORE ops).
478    /// Matches Tinygrad's `outs` field.
479    pub outs: Vec<usize>,
480
481    /// Input buffer indices (read by LOAD ops, excluding outputs).
482    /// Matches Tinygrad's `ins` field.
483    pub ins: Vec<usize>,
484
485    /// Number of buffer arguments (for CIF construction at compile time).
486    pub buf_count: usize,
487}
488
489#[derive(Debug)]
490struct DerivedProgramMetadata {
491    vars: Vec<Variable>,
492    var_names: Vec<String>,
493    globals: Vec<usize>,
494    outs: Vec<usize>,
495    ins: Vec<usize>,
496    global_size: [Arc<UOp>; 3],
497    local_size: Option<[Arc<UOp>; 3]>,
498}
499
/// Which launch dimension a SPECIAL op name refers to, as classified by
/// `special_launch_axis` from the name's first character.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum LaunchDimKind {
    /// Name starts with `g`: sets a global-size axis.
    Global,
    /// Name starts with `l`: sets a local-size axis.
    Local,
    /// Name starts with `i`: sets a global-size axis and clears the local size.
    DirectGlobal,
}
506
507impl ProgramSpec {
508    /// Create a new program specification.
509    pub fn new(name: String, src: String, device: DeviceSpec, ast: Arc<UOp>) -> Self {
510        Self {
511            name,
512            src,
513            device,
514            ast,
515            global_size: default_launch_size(),
516            local_size: Some(default_launch_size()),
517            vars: Vec::new(),
518            var_names: Vec::new(),
519            globals: Vec::new(),
520            outs: Vec::new(),
521            ins: Vec::new(),
522            buf_count: 0,
523        }
524    }
525
526    /// Add a variable to the program.
527    pub fn add_var(&mut self, var: Variable) {
528        self.vars.push(var);
529    }
530
531    /// Set work sizes for GPU execution.
532    pub fn set_work_sizes(&mut self, global: [usize; 3], local: [usize; 3]) {
533        self.global_size = concrete_launch_size(global);
534        self.local_size = Some(concrete_launch_size(local));
535    }
536
537    /// Set symbolic work sizes for replay with runtime variables.
538    pub fn set_launch_dims(&mut self, global: [Arc<UOp>; 3], local: Option<[Arc<UOp>; 3]>) {
539        self.global_size = global;
540        self.local_size = local;
541    }
542
543    /// Evaluate symbolic launch dimensions using runtime variable values.
544    pub fn launch_dims(&self, var_vals: &HashMap<&str, i64>) -> Result<ConcreteLaunchDims> {
545        Self::resolve_launch_dims(&self.global_size, self.local_size.as_ref(), var_vals)
546    }
547
548    /// Evaluate launch dimensions stored outside a full ProgramSpec.
549    pub fn resolve_launch_dims(
550        global_size: &[Arc<UOp>; 3],
551        local_size: Option<&[Arc<UOp>; 3]>,
552        var_vals: &HashMap<&str, i64>,
553    ) -> Result<ConcreteLaunchDims> {
554        Ok(ConcreteLaunchDims {
555            global_size: eval_launch_size(global_size, var_vals)?,
556            local_size: local_size.map(|local| eval_launch_size(local, var_vals)).transpose()?,
557        })
558    }
559
560    /// Set variable names for populating vars array at runtime.
561    pub fn set_var_names(&mut self, var_names: Vec<String>) {
562        self.var_names = var_names;
563    }
564
565    /// Set buffer metadata (globals, outs, ins).
566    pub fn set_buffer_metadata(&mut self, globals: Vec<usize>, outs: Vec<usize>, ins: Vec<usize>) {
567        self.globals = globals;
568        self.outs = outs;
569        self.ins = ins;
570    }
571
572    /// Derive and apply metadata from `self.ast`.
573    ///
574    /// This mirrors Tinygrad-style program metadata extraction from the kernel
575    /// graph and keeps renderer wrappers aligned on one metadata path.
576    pub fn apply_derived_metadata_from_ast(&mut self) {
577        let derived = Self::derive_metadata_from_sink(&self.ast);
578        self.globals = derived.globals;
579        self.outs = derived.outs;
580        self.ins = derived.ins;
581        if self.vars.is_empty() {
582            self.vars = derived.vars;
583        }
584        if self.var_names.is_empty() {
585            self.var_names = derived.var_names;
586        }
587        if self.buf_count == 0 {
588            self.buf_count = self.globals.len();
589        }
590        self.global_size = derived.global_size;
591        self.local_size = derived.local_size;
592    }
593
594    fn special_launch_axis(name: &str) -> Option<(LaunchDimKind, usize)> {
595        let kind = match name.chars().next()? {
596            'g' => LaunchDimKind::Global,
597            'l' => LaunchDimKind::Local,
598            'i' => LaunchDimKind::DirectGlobal,
599            _ => return None,
600        };
601        let suffix_start = name.rfind(|ch: char| !ch.is_ascii_digit()).map(|idx| idx + 1).unwrap_or(0);
602        if suffix_start == name.len() {
603            return None;
604        }
605        let axis = name[suffix_start..].parse::<usize>().ok()?;
606        (axis < 3).then_some((kind, axis))
607    }
608
609    fn is_const_one(uop: &Arc<UOp>) -> bool {
610        matches!(uop.op(), Op::Const(value) if matches!(value.0, ConstValue::Int(1) | ConstValue::UInt(1)))
611    }
612
613    fn has_non_default_launch_dims(&self) -> bool {
614        !self.global_size.iter().all(Self::is_const_one)
615            || !matches!(&self.local_size, Some(local) if local.iter().all(Self::is_const_one))
616    }
617
618    fn extract_param_slot_from_index(index: &Arc<UOp>) -> Option<usize> {
619        fn slot_from_buffer(buffer: &Arc<UOp>) -> Option<usize> {
620            if let Op::Param { slot, device: None, .. } = buffer.op() { Some(*slot) } else { None }
621        }
622
623        match index.op() {
624            Op::Index { buffer, .. } => slot_from_buffer(buffer),
625            Op::Cast { src, .. } => match src.op() {
626                Op::Index { buffer, .. } => slot_from_buffer(buffer),
627                _ => None,
628            },
629            _ => None,
630        }
631    }
632
633    fn derive_metadata_from_sink(sink: &Arc<UOp>) -> DerivedProgramMetadata {
634        let mut vars = Vec::new();
635        let mut globals = Vec::new();
636        let mut outs = Vec::new();
637        let mut ins = Vec::new();
638        let mut global_size = default_launch_size();
639        let mut local_size = Some(default_launch_size());
640
641        for node in sink.toposort() {
642            match node.op() {
643                Op::DefineVar { name, min_val, max_val } => {
644                    vars.push(Variable::new(name.clone(), *min_val, *max_val));
645                    if name == "core_id" {
646                        global_size[0] = UOp::index_const(max_val.saturating_add(1));
647                    }
648                }
649                Op::Param { slot, device: None, .. } => {
650                    globals.push(*slot);
651                }
652                Op::Special { end, name } => {
653                    if let Some((kind, axis)) = Self::special_launch_axis(name) {
654                        match kind {
655                            LaunchDimKind::Global => global_size[axis] = end.clone(),
656                            LaunchDimKind::Local => {
657                                local_size.get_or_insert_with(default_launch_size)[axis] = end.clone()
658                            }
659                            LaunchDimKind::DirectGlobal => {
660                                global_size[axis] = end.clone();
661                                local_size = None;
662                            }
663                        }
664                    }
665                }
666                Op::Store { index, .. } => {
667                    if let Some(slot) = Self::extract_param_slot_from_index(index) {
668                        outs.push(slot);
669                    }
670                }
671                Op::Load { index, .. } => {
672                    if let Some(slot) = Self::extract_param_slot_from_index(index) {
673                        ins.push(slot);
674                    }
675                }
676                _ => {}
677            }
678        }
679
680        vars.sort_by(|a, b| a.name.cmp(&b.name));
681        vars.dedup_by(|a, b| a.name == b.name);
682        let var_names = vars.iter().map(|v| v.name.clone()).collect();
683
684        globals.sort_unstable();
685        globals.dedup();
686
687        outs.sort_unstable();
688        outs.dedup();
689
690        ins.sort_unstable();
691        ins.dedup();
692
693        DerivedProgramMetadata { vars, var_names, globals, outs, ins, global_size, local_size }
694    }
695
696    /// Build a ProgramSpec from a PROGRAM UOp state.
697    ///
698    /// Validates PROGRAM stage shape and derives metadata from PROGRAM itself.
699    pub fn from_uop(program: &Arc<UOp>) -> Result<Self> {
700        let Op::Program { sink, device, linear, source, binary } = program.op() else {
701            return Err(Error::Runtime { message: format!("expected PROGRAM op, got {:?}", program.op()) });
702        };
703
704        if !matches!(sink.op(), Op::Sink { .. }) {
705            return Err(Error::Runtime { message: format!("PROGRAM sink stage must be SINK op, got {:?}", sink.op()) });
706        }
707
708        let device_spec = match device.op() {
709            Op::Device(spec) => spec.clone(),
710            _ => {
711                return Err(Error::Runtime {
712                    message: format!("PROGRAM device must be DEVICE op, got {:?}", device.op()),
713                });
714            }
715        };
716
717        let linear =
718            linear.as_ref().ok_or_else(|| Error::Runtime { message: "PROGRAM missing LINEAR stage".to_string() })?;
719        if !matches!(linear.op(), Op::Linear { .. }) {
720            return Err(Error::Runtime {
721                message: format!("PROGRAM linear stage must be LINEAR op, got {:?}", linear.op()),
722            });
723        }
724
725        let source =
726            source.as_ref().ok_or_else(|| Error::Runtime { message: "PROGRAM missing SOURCE stage".to_string() })?;
727        let source_code = match source.op() {
728            Op::Source { code } => code.clone(),
729            _ => {
730                return Err(Error::Runtime {
731                    message: format!("PROGRAM source stage must be SOURCE op, got {:?}", source.op()),
732                });
733            }
734        };
735
736        if let Some(binary) = binary
737            && !matches!(binary.op(), Op::ProgramBinary { .. })
738        {
739            return Err(Error::Runtime {
740                message: format!("PROGRAM binary stage must be ProgramBinary op, got {:?}", binary.op()),
741            });
742        }
743
744        let derived = Self::derive_metadata_from_sink(sink);
745        let meta = program.metadata::<ProgramSpec>();
746
747        let name = meta.as_ref().map(|m| m.name.clone()).unwrap_or_else(|| "kernel".to_string());
748
749        let mut spec = Self::new(name, source_code, device_spec, sink.clone());
750        spec.vars = meta.as_ref().map(|m| m.vars.clone()).filter(|vars| !vars.is_empty()).unwrap_or(derived.vars);
751        spec.var_names =
752            meta.as_ref().map(|m| m.var_names.clone()).filter(|names| !names.is_empty()).unwrap_or(derived.var_names);
753        spec.globals =
754            meta.as_ref().map(|m| m.globals.clone()).filter(|globals| !globals.is_empty()).unwrap_or(derived.globals);
755        spec.outs = meta.as_ref().map(|m| m.outs.clone()).filter(|outs| !outs.is_empty()).unwrap_or(derived.outs);
756        spec.ins = meta.as_ref().map(|m| m.ins.clone()).filter(|ins| !ins.is_empty()).unwrap_or(derived.ins);
757        spec.buf_count = meta.as_ref().map(|m| m.buf_count).filter(|count| *count > 0).unwrap_or(spec.globals.len());
758        let meta_launch = meta.as_ref().filter(|m| m.has_non_default_launch_dims());
759        spec.global_size = meta_launch.map(|m| m.global_size.clone()).unwrap_or(derived.global_size);
760        spec.local_size = meta_launch.map(|m| m.local_size.clone()).unwrap_or(derived.local_size);
761
762        Ok(spec)
763    }
764}
765
/// A symbolic kernel variable (used for symbolic shapes/strides).
///
/// A variable stands for a value that is bound when the kernel is executed,
/// for example:
/// - a shape dimension that differs between inputs
/// - a stride computed from shapes
/// - a loop bound that depends on input sizes
#[derive(Debug, Clone)]
pub struct Variable {
    /// Name of the variable (must be unique within the kernel).
    pub name: String,

    /// Smallest allowed value (for range validation).
    pub min: i64,

    /// Largest allowed value (for range validation).
    pub max: i64,
}
784
785impl Variable {
786    /// Create a new variable.
787    pub fn new(name: String, min: i64, max: i64) -> Self {
788        Self { name, min, max }
789    }
790}