1use std::collections::HashMap;
13use std::sync::Arc;
14
15use svod_dtype::DeviceSpec;
16use svod_ir::{BinaryOp, ConstValue, Op, UOp, UnaryOp};
17
18use crate::allocator::Allocator;
19use crate::error::{Error, Result};
20
/// An executable kernel handle produced by a [`RuntimeFactory`].
///
/// Implementors must be `Send + Sync`, so a program handle can be shared
/// across threads.
pub trait Program: Send + Sync {
    /// Launches the program.
    ///
    /// * `buffers` — raw pointers to the kernel's buffer arguments.
    /// * `vals` — integer values for the kernel's symbolic variables
    ///   (presumably in the order of `CompiledSpec::var_names` — TODO confirm).
    /// * `global_size` / `local_size` — optional concrete launch dimensions;
    ///   NOTE(review): the meaning of `None` is backend-defined — confirm per
    ///   implementation.
    ///
    /// # Errors
    /// Returns an error if the launch fails.
    ///
    /// # Safety
    /// NOTE(review): the contract is not stated here. Callers presumably must
    /// ensure every pointer in `buffers` is valid, large enough for the
    /// kernel's accesses, and not aliased in ways the kernel does not expect —
    /// confirm against the implementations.
    unsafe fn execute(
        &self,
        buffers: &[*mut u8],
        vals: &[i64],
        global_size: Option<[usize; 3]>,
        local_size: Option<[usize; 3]>,
    ) -> Result<()>;

    /// Returns the program's name.
    fn name(&self) -> &str;
}
63
/// Output of [`Compiler::compile`] and input to a [`RuntimeFactory`].
#[derive(Debug, Clone)]
pub struct CompiledSpec {
    /// Kernel name.
    pub name: String,
    /// Source text; `None` when the spec was built from a binary.
    pub src: Option<String>,
    /// Compiled binary payload; empty when the spec was built from source.
    pub bytes: Vec<u8>,
    /// The AST this spec corresponds to.
    pub ast: Arc<UOp>,
    /// Variable names (the constructors in this file leave this empty).
    pub var_names: Vec<String>,
    /// Global launch dimensions as (possibly symbolic) index expressions.
    pub global_size: [Arc<UOp>; 3],
    /// Local launch dimensions; `None` means no local size was specified.
    pub local_size: Option<[Arc<UOp>; 3]>,
    /// Buffer count (`from_bytes` sets this to 0).
    pub buf_count: usize,
}
102
103impl CompiledSpec {
104 pub fn from_source(name: String, src: String, ast: Arc<UOp>, buf_count: usize) -> Self {
106 Self {
107 name,
108 src: Some(src),
109 bytes: Vec::new(),
110 ast,
111 var_names: Vec::new(),
112 global_size: default_launch_size(),
113 local_size: Some(default_launch_size()),
114 buf_count,
115 }
116 }
117
118 pub fn from_bytes(name: String, bytes: Vec<u8>, ast: Arc<UOp>) -> Self {
120 Self {
121 name,
122 src: None,
123 bytes,
124 ast,
125 var_names: Vec::new(),
126 global_size: default_launch_size(),
127 local_size: Some(default_launch_size()),
128 buf_count: 0,
129 }
130 }
131
132 pub fn from_source_with_sizes(
134 name: String,
135 src: String,
136 ast: Arc<UOp>,
137 global_size: [usize; 3],
138 local_size: Option<[usize; 3]>,
139 buf_count: usize,
140 ) -> Self {
141 Self {
142 name,
143 src: Some(src),
144 bytes: Vec::new(),
145 ast,
146 var_names: Vec::new(),
147 global_size: concrete_launch_size(global_size),
148 local_size: local_size.map(concrete_launch_size),
149 buf_count,
150 }
151 }
152}
153
/// Launch dimensions with every expression evaluated to a concrete size.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ConcreteLaunchDims {
    /// Evaluated global launch size.
    pub global_size: [usize; 3],
    /// Evaluated local launch size; `None` when no local size expression was given.
    pub local_size: Option<[usize; 3]>,
}
160
161fn default_launch_size() -> [Arc<UOp>; 3] {
162 [UOp::index_const(1), UOp::index_const(1), UOp::index_const(1)]
163}
164
165fn concrete_launch_size(size: [usize; 3]) -> [Arc<UOp>; 3] {
166 [UOp::index_const(size[0] as i64), UOp::index_const(size[1] as i64), UOp::index_const(size[2] as i64)]
167}
168
169fn const_value_to_i64(value: ConstValue) -> Result<i64> {
170 match value {
171 ConstValue::Int(v) => Ok(v),
172 ConstValue::UInt(v) => i64::try_from(v)
173 .map_err(|_| Error::Runtime { message: format!("launch-size constant {v} does not fit i64") }),
174 ConstValue::Bool(v) => Ok(i64::from(v)),
175 ConstValue::Float(v) => {
176 Err(Error::Runtime { message: format!("launch-size expression must be integer, got float constant {v}") })
177 }
178 }
179}
180
181fn validate_var_bound(name: &str, value: i64, min_val: i64, max_val: i64) -> Result<()> {
182 if value < min_val || value > max_val {
183 return Err(Error::Runtime {
184 message: format!("variable {name}={value} is outside bounds [{min_val}, {max_val}]"),
185 });
186 }
187 Ok(())
188}
189
190fn checked_launch_binary(op: BinaryOp, lhs: i64, rhs: i64) -> Result<i64> {
191 let value = match op {
192 BinaryOp::Add => lhs.checked_add(rhs),
193 BinaryOp::Sub => lhs.checked_sub(rhs),
194 BinaryOp::Mul => lhs.checked_mul(rhs),
195 BinaryOp::Idiv => (rhs != 0).then(|| lhs.checked_div(rhs)).flatten(),
196 BinaryOp::Mod => (rhs != 0).then(|| lhs.checked_rem(rhs)).flatten(),
197 BinaryOp::Max => Some(lhs.max(rhs)),
198 _ => {
199 return Err(Error::Runtime { message: format!("unsupported binary op in launch-size expression: {op:?}") });
200 }
201 };
202
203 value.ok_or_else(|| Error::Runtime { message: format!("invalid launch-size arithmetic: {lhs} {op:?} {rhs}") })
204}
205
206fn eval_launch_expr(expr: &Arc<UOp>, vars: &HashMap<&str, i64>) -> Result<i64> {
207 match expr.op() {
208 Op::Const(value) => const_value_to_i64(value.0),
209 Op::DefineVar { name, min_val, max_val } => {
210 let value = vars.get(name.as_str()).copied().ok_or_else(|| Error::Runtime {
211 message: format!("missing runtime value for launch-size variable {name}"),
212 })?;
213 validate_var_bound(name, value, *min_val, *max_val)?;
214 Ok(value)
215 }
216 Op::Bind { var, value } => {
217 let bound = eval_launch_expr(value, vars)?;
218 if let Op::DefineVar { name, min_val, max_val } = var.op() {
219 validate_var_bound(name, bound, *min_val, *max_val)?;
220 }
221 Ok(bound)
222 }
223 Op::Binary(op, lhs, rhs) => {
224 checked_launch_binary(*op, eval_launch_expr(lhs, vars)?, eval_launch_expr(rhs, vars)?)
225 }
226 Op::Unary(UnaryOp::Neg, src) => eval_launch_expr(src, vars)?
227 .checked_neg()
228 .ok_or_else(|| Error::Runtime { message: "invalid launch-size negation overflow".to_string() }),
229 Op::Unary(UnaryOp::Abs, src) => eval_launch_expr(src, vars)?
230 .checked_abs()
231 .ok_or_else(|| Error::Runtime { message: "invalid launch-size abs overflow".to_string() }),
232 Op::Cast { src, .. } | Op::BitCast { src, .. } | Op::After { passthrough: src, .. } => {
233 eval_launch_expr(src, vars)
234 }
235 other => Err(Error::Runtime { message: format!("unsupported launch-size expression op: {other:?}") }),
236 }
237}
238
239fn eval_launch_size(size: &[Arc<UOp>; 3], vars: &HashMap<&str, i64>) -> Result<[usize; 3]> {
240 let mut out = [1usize; 3];
241 for (idx, expr) in size.iter().enumerate() {
242 let value = eval_launch_expr(expr, vars)?;
243 if value <= 0 {
244 return Err(Error::Runtime {
245 message: format!("launch dimension {idx} evaluated to non-positive value {value}"),
246 });
247 }
248 out[idx] = usize::try_from(value).map_err(|_| Error::Runtime {
249 message: format!("launch dimension {idx} value {value} does not fit usize"),
250 })?;
251 }
252 Ok(out)
253}
254
/// Turns a rendered [`ProgramSpec`] into a [`CompiledSpec`].
pub trait Compiler: Send + Sync {
    /// Compiles `spec`.
    ///
    /// # Errors
    /// Returns an error if compilation fails.
    fn compile(&self, spec: &ProgramSpec) -> Result<CompiledSpec>;

    /// Static string key for this compiler.
    /// NOTE(review): the name suggests it keys compilation caches — confirm at call sites.
    fn cache_key(&self) -> &'static str;
}
298
/// Renders a kernel AST into a [`ProgramSpec`].
pub trait Renderer: Send + Sync {
    /// Renders `ast` into a program spec, optionally with the kernel name `name`.
    /// NOTE(review): how `None` is handled is implementation-defined — confirm.
    ///
    /// # Errors
    /// Returns an error if rendering fails.
    fn render(&self, ast: &Arc<UOp>, name: Option<&str>) -> Result<ProgramSpec>;

    /// Returns the device spec associated with this renderer.
    fn device(&self) -> &DeviceSpec;

    /// Optional pattern matcher supplied by this renderer; the default is `None`.
    /// NOTE(review): the name suggests it decomposes ops the backend cannot
    /// render directly — confirm against implementations.
    fn decompositor(&self) -> Option<svod_ir::pattern::TypedPatternMatcher<()>> {
        None
    }
}
343
/// Shared, thread-safe factory that builds an executable [`Program`] from a [`CompiledSpec`].
pub type RuntimeFactory = Arc<dyn Fn(&CompiledSpec) -> Result<Box<dyn Program>> + Send + Sync>;

/// A [`Renderer`] paired with a [`Compiler`]; the renderer produces the
/// [`ProgramSpec`] that the compiler accepts.
pub type CompilerPair = (Arc<dyn Renderer>, Arc<dyn Compiler>);
360
/// A device backend: allocator, renderer/compiler pipeline and runtime factory.
pub struct Device {
    /// The device spec this backend is built for.
    pub device: DeviceSpec,
    /// Buffer allocator for this device.
    pub allocator: Arc<dyn Allocator>,
    /// Renderer/compiler pairs; `Device::new` seeds it with the single
    /// (`renderer`, `compiler`) pair.
    pub compilers: Vec<CompilerPair>,
    /// Default renderer (the first entry of `compilers` when built with `Device::new`).
    pub renderer: Arc<dyn Renderer>,
    /// Default compiler (the first entry of `compilers` when built with `Device::new`).
    pub compiler: Arc<dyn Compiler>,
    /// Factory that turns compiled specs into executable programs.
    pub runtime: RuntimeFactory,
}
403
404impl Device {
405 pub fn new(
410 device: DeviceSpec,
411 allocator: Arc<dyn Allocator>,
412 renderer: Arc<dyn Renderer>,
413 compiler: Arc<dyn Compiler>,
414 runtime: RuntimeFactory,
415 ) -> Self {
416 let compilers = vec![(renderer.clone(), compiler.clone())];
417 Self { device, allocator, compilers, renderer, compiler, runtime }
418 }
419
420 pub fn base_device_key(&self) -> &'static str {
431 self.device.base_type()
432 }
433}
434
/// A rendered kernel: source text plus the metadata needed to compile and launch it.
#[derive(Debug, Clone)]
pub struct ProgramSpec {
    /// Kernel name (`from_uop` falls back to `"kernel"` when no metadata is attached).
    pub name: String,
    /// Rendered source code.
    pub src: String,
    /// Device spec (taken from the PROGRAM's DEVICE stage in `from_uop`).
    pub device: DeviceSpec,
    /// Kernel AST; `from_uop` stores the program's SINK here.
    pub ast: Arc<UOp>,
    /// Global launch dimensions as (possibly symbolic) index expressions.
    pub global_size: [Arc<UOp>; 3],
    /// Local launch dimensions; `None` when the kernel uses direct-global (`i*`) indexing.
    pub local_size: Option<[Arc<UOp>; 3]>,
    /// Symbolic variables with their bounds.
    pub vars: Vec<Variable>,
    /// Variable names (sorted and deduplicated when derived from the AST).
    pub var_names: Vec<String>,
    /// Buffer parameter slots (derived: sorted, deduplicated device-less `Param` slots).
    pub globals: Vec<usize>,
    /// Parameter slots written by `Store` ops (derived: sorted, deduplicated).
    pub outs: Vec<usize>,
    /// Parameter slots read by `Load` ops (derived: sorted, deduplicated).
    pub ins: Vec<usize>,
    /// Buffer count; filled from `globals.len()` when left at `0` by
    /// `apply_derived_metadata_from_ast` / `from_uop`.
    pub buf_count: usize,
}
488
/// Metadata recovered from a kernel SINK by `ProgramSpec::derive_metadata_from_sink`.
#[derive(Debug)]
struct DerivedProgramMetadata {
    /// `DefineVar` nodes, sorted and deduplicated by name.
    vars: Vec<Variable>,
    /// Names of `vars`, in the same order.
    var_names: Vec<String>,
    /// Sorted, deduplicated slots of device-less `Param` nodes.
    globals: Vec<usize>,
    /// Sorted, deduplicated parameter slots written by `Store`.
    outs: Vec<usize>,
    /// Sorted, deduplicated parameter slots read by `Load`.
    ins: Vec<usize>,
    /// Global launch size derived from `Special` / `core_id` nodes.
    global_size: [Arc<UOp>; 3],
    /// Local launch size; `None` after a direct-global (`i*`) special.
    local_size: Option<[Arc<UOp>; 3]>,
}
499
/// Which launch dimension a `Special` node controls, selected by the first
/// character of its name.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum LaunchDimKind {
    /// Name starts with `g`: sets a global dimension.
    Global,
    /// Name starts with `l`: sets a local dimension.
    Local,
    /// Name starts with `i`: sets a global dimension and clears the local size.
    DirectGlobal,
}
506
507impl ProgramSpec {
508 pub fn new(name: String, src: String, device: DeviceSpec, ast: Arc<UOp>) -> Self {
510 Self {
511 name,
512 src,
513 device,
514 ast,
515 global_size: default_launch_size(),
516 local_size: Some(default_launch_size()),
517 vars: Vec::new(),
518 var_names: Vec::new(),
519 globals: Vec::new(),
520 outs: Vec::new(),
521 ins: Vec::new(),
522 buf_count: 0,
523 }
524 }
525
526 pub fn add_var(&mut self, var: Variable) {
528 self.vars.push(var);
529 }
530
531 pub fn set_work_sizes(&mut self, global: [usize; 3], local: [usize; 3]) {
533 self.global_size = concrete_launch_size(global);
534 self.local_size = Some(concrete_launch_size(local));
535 }
536
537 pub fn set_launch_dims(&mut self, global: [Arc<UOp>; 3], local: Option<[Arc<UOp>; 3]>) {
539 self.global_size = global;
540 self.local_size = local;
541 }
542
543 pub fn launch_dims(&self, var_vals: &HashMap<&str, i64>) -> Result<ConcreteLaunchDims> {
545 Self::resolve_launch_dims(&self.global_size, self.local_size.as_ref(), var_vals)
546 }
547
548 pub fn resolve_launch_dims(
550 global_size: &[Arc<UOp>; 3],
551 local_size: Option<&[Arc<UOp>; 3]>,
552 var_vals: &HashMap<&str, i64>,
553 ) -> Result<ConcreteLaunchDims> {
554 Ok(ConcreteLaunchDims {
555 global_size: eval_launch_size(global_size, var_vals)?,
556 local_size: local_size.map(|local| eval_launch_size(local, var_vals)).transpose()?,
557 })
558 }
559
560 pub fn set_var_names(&mut self, var_names: Vec<String>) {
562 self.var_names = var_names;
563 }
564
565 pub fn set_buffer_metadata(&mut self, globals: Vec<usize>, outs: Vec<usize>, ins: Vec<usize>) {
567 self.globals = globals;
568 self.outs = outs;
569 self.ins = ins;
570 }
571
572 pub fn apply_derived_metadata_from_ast(&mut self) {
577 let derived = Self::derive_metadata_from_sink(&self.ast);
578 self.globals = derived.globals;
579 self.outs = derived.outs;
580 self.ins = derived.ins;
581 if self.vars.is_empty() {
582 self.vars = derived.vars;
583 }
584 if self.var_names.is_empty() {
585 self.var_names = derived.var_names;
586 }
587 if self.buf_count == 0 {
588 self.buf_count = self.globals.len();
589 }
590 self.global_size = derived.global_size;
591 self.local_size = derived.local_size;
592 }
593
594 fn special_launch_axis(name: &str) -> Option<(LaunchDimKind, usize)> {
595 let kind = match name.chars().next()? {
596 'g' => LaunchDimKind::Global,
597 'l' => LaunchDimKind::Local,
598 'i' => LaunchDimKind::DirectGlobal,
599 _ => return None,
600 };
601 let suffix_start = name.rfind(|ch: char| !ch.is_ascii_digit()).map(|idx| idx + 1).unwrap_or(0);
602 if suffix_start == name.len() {
603 return None;
604 }
605 let axis = name[suffix_start..].parse::<usize>().ok()?;
606 (axis < 3).then_some((kind, axis))
607 }
608
609 fn is_const_one(uop: &Arc<UOp>) -> bool {
610 matches!(uop.op(), Op::Const(value) if matches!(value.0, ConstValue::Int(1) | ConstValue::UInt(1)))
611 }
612
613 fn has_non_default_launch_dims(&self) -> bool {
614 !self.global_size.iter().all(Self::is_const_one)
615 || !matches!(&self.local_size, Some(local) if local.iter().all(Self::is_const_one))
616 }
617
618 fn extract_param_slot_from_index(index: &Arc<UOp>) -> Option<usize> {
619 fn slot_from_buffer(buffer: &Arc<UOp>) -> Option<usize> {
620 if let Op::Param { slot, device: None, .. } = buffer.op() { Some(*slot) } else { None }
621 }
622
623 match index.op() {
624 Op::Index { buffer, .. } => slot_from_buffer(buffer),
625 Op::Cast { src, .. } => match src.op() {
626 Op::Index { buffer, .. } => slot_from_buffer(buffer),
627 _ => None,
628 },
629 _ => None,
630 }
631 }
632
633 fn derive_metadata_from_sink(sink: &Arc<UOp>) -> DerivedProgramMetadata {
634 let mut vars = Vec::new();
635 let mut globals = Vec::new();
636 let mut outs = Vec::new();
637 let mut ins = Vec::new();
638 let mut global_size = default_launch_size();
639 let mut local_size = Some(default_launch_size());
640
641 for node in sink.toposort() {
642 match node.op() {
643 Op::DefineVar { name, min_val, max_val } => {
644 vars.push(Variable::new(name.clone(), *min_val, *max_val));
645 if name == "core_id" {
646 global_size[0] = UOp::index_const(max_val.saturating_add(1));
647 }
648 }
649 Op::Param { slot, device: None, .. } => {
650 globals.push(*slot);
651 }
652 Op::Special { end, name } => {
653 if let Some((kind, axis)) = Self::special_launch_axis(name) {
654 match kind {
655 LaunchDimKind::Global => global_size[axis] = end.clone(),
656 LaunchDimKind::Local => {
657 local_size.get_or_insert_with(default_launch_size)[axis] = end.clone()
658 }
659 LaunchDimKind::DirectGlobal => {
660 global_size[axis] = end.clone();
661 local_size = None;
662 }
663 }
664 }
665 }
666 Op::Store { index, .. } => {
667 if let Some(slot) = Self::extract_param_slot_from_index(index) {
668 outs.push(slot);
669 }
670 }
671 Op::Load { index, .. } => {
672 if let Some(slot) = Self::extract_param_slot_from_index(index) {
673 ins.push(slot);
674 }
675 }
676 _ => {}
677 }
678 }
679
680 vars.sort_by(|a, b| a.name.cmp(&b.name));
681 vars.dedup_by(|a, b| a.name == b.name);
682 let var_names = vars.iter().map(|v| v.name.clone()).collect();
683
684 globals.sort_unstable();
685 globals.dedup();
686
687 outs.sort_unstable();
688 outs.dedup();
689
690 ins.sort_unstable();
691 ins.dedup();
692
693 DerivedProgramMetadata { vars, var_names, globals, outs, ins, global_size, local_size }
694 }
695
696 pub fn from_uop(program: &Arc<UOp>) -> Result<Self> {
700 let Op::Program { sink, device, linear, source, binary } = program.op() else {
701 return Err(Error::Runtime { message: format!("expected PROGRAM op, got {:?}", program.op()) });
702 };
703
704 if !matches!(sink.op(), Op::Sink { .. }) {
705 return Err(Error::Runtime { message: format!("PROGRAM sink stage must be SINK op, got {:?}", sink.op()) });
706 }
707
708 let device_spec = match device.op() {
709 Op::Device(spec) => spec.clone(),
710 _ => {
711 return Err(Error::Runtime {
712 message: format!("PROGRAM device must be DEVICE op, got {:?}", device.op()),
713 });
714 }
715 };
716
717 let linear =
718 linear.as_ref().ok_or_else(|| Error::Runtime { message: "PROGRAM missing LINEAR stage".to_string() })?;
719 if !matches!(linear.op(), Op::Linear { .. }) {
720 return Err(Error::Runtime {
721 message: format!("PROGRAM linear stage must be LINEAR op, got {:?}", linear.op()),
722 });
723 }
724
725 let source =
726 source.as_ref().ok_or_else(|| Error::Runtime { message: "PROGRAM missing SOURCE stage".to_string() })?;
727 let source_code = match source.op() {
728 Op::Source { code } => code.clone(),
729 _ => {
730 return Err(Error::Runtime {
731 message: format!("PROGRAM source stage must be SOURCE op, got {:?}", source.op()),
732 });
733 }
734 };
735
736 if let Some(binary) = binary
737 && !matches!(binary.op(), Op::ProgramBinary { .. })
738 {
739 return Err(Error::Runtime {
740 message: format!("PROGRAM binary stage must be ProgramBinary op, got {:?}", binary.op()),
741 });
742 }
743
744 let derived = Self::derive_metadata_from_sink(sink);
745 let meta = program.metadata::<ProgramSpec>();
746
747 let name = meta.as_ref().map(|m| m.name.clone()).unwrap_or_else(|| "kernel".to_string());
748
749 let mut spec = Self::new(name, source_code, device_spec, sink.clone());
750 spec.vars = meta.as_ref().map(|m| m.vars.clone()).filter(|vars| !vars.is_empty()).unwrap_or(derived.vars);
751 spec.var_names =
752 meta.as_ref().map(|m| m.var_names.clone()).filter(|names| !names.is_empty()).unwrap_or(derived.var_names);
753 spec.globals =
754 meta.as_ref().map(|m| m.globals.clone()).filter(|globals| !globals.is_empty()).unwrap_or(derived.globals);
755 spec.outs = meta.as_ref().map(|m| m.outs.clone()).filter(|outs| !outs.is_empty()).unwrap_or(derived.outs);
756 spec.ins = meta.as_ref().map(|m| m.ins.clone()).filter(|ins| !ins.is_empty()).unwrap_or(derived.ins);
757 spec.buf_count = meta.as_ref().map(|m| m.buf_count).filter(|count| *count > 0).unwrap_or(spec.globals.len());
758 let meta_launch = meta.as_ref().filter(|m| m.has_non_default_launch_dims());
759 spec.global_size = meta_launch.map(|m| m.global_size.clone()).unwrap_or(derived.global_size);
760 spec.local_size = meta_launch.map(|m| m.local_size.clone()).unwrap_or(derived.local_size);
761
762 Ok(spec)
763 }
764}
765
/// A named symbolic variable with integer bounds.
#[derive(Debug, Clone)]
pub struct Variable {
    /// Variable name.
    pub name: String,
    /// Lower bound (a `DefineVar`'s `min_val` when derived from an AST).
    pub min: i64,
    /// Upper bound (a `DefineVar`'s `max_val` when derived from an AST).
    pub max: i64,
}
784
785impl Variable {
786 pub fn new(name: String, min: i64, max: i64) -> Self {
788 Self { name, min, max }
789 }
790}