1use std::sync::Arc;
2
3use svod_device::device::{CompiledSpec, Compiler, ProgramSpec, Renderer};
4use svod_device::{Error, Result};
5use svod_dtype::DeviceSpec;
6use svod_ir::{Op, UOp};
7use svod_schedule::linearize::line_rewrite_cleanups;
8
/// Components of a `PROGRAM` op as returned by `unpack_program`, in the order
/// `(sink, device, linear, source, binary)`. The last three stages are optional.
type ProgramParts = (Arc<UOp>, Arc<UOp>, Option<Arc<UOp>>, Option<Arc<UOp>>, Option<Arc<UOp>>);
10
/// Pipeline stage that `get_program` should advance a `PROGRAM` to.
///
/// Stages are cumulative: reaching `Source` also ensures the LINEAR stage is
/// present, and reaching `Binary` also ensures LINEAR and SOURCE are present.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProgramTarget {
    /// Stop once the LINEAR stage is present.
    Linear,
    /// Stop once the SOURCE stage is present (rendering if it is missing).
    Source,
    /// Stop once the BINARY stage is present (compiling if it is missing).
    Binary,
}
17
18fn invalid_program_state(details: impl Into<String>) -> Error {
19 Error::Runtime { message: details.into() }
20}
21
22fn unpack_program(program: &Arc<UOp>) -> Result<ProgramParts> {
23 let Op::Program { sink, device, linear, source, binary } = program.op() else {
24 return Err(invalid_program_state(format!("expected PROGRAM op, got {:?}", program.op())));
25 };
26 Ok((sink.clone(), device.clone(), linear.clone(), source.clone(), binary.clone()))
27}
28
29fn validate_program_shape(program: &Arc<UOp>) -> Result<()> {
30 let (sink, device, linear, source, binary) = unpack_program(program)?;
31
32 if !matches!(sink.op(), Op::Sink { .. }) {
33 return Err(invalid_program_state(format!("PROGRAM sink must be SINK op, got {:?}", sink.op())));
34 }
35
36 if !matches!(device.op(), Op::Device(_)) {
37 return Err(invalid_program_state(format!("PROGRAM device must be DEVICE op, got {:?}", device.op())));
38 }
39
40 if let Some(linear) = &linear
41 && !matches!(linear.op(), Op::Linear { .. })
42 {
43 return Err(invalid_program_state(format!("PROGRAM linear stage must be LINEAR op, got {:?}", linear.op())));
44 }
45
46 if let Some(source) = &source
47 && !matches!(source.op(), Op::Source { .. })
48 {
49 return Err(invalid_program_state(format!("PROGRAM source stage must be SOURCE op, got {:?}", source.op())));
50 }
51
52 if let Some(binary) = &binary
53 && !matches!(binary.op(), Op::ProgramBinary { .. })
54 {
55 return Err(invalid_program_state(format!(
56 "PROGRAM binary stage must be ProgramBinary op, got {:?}",
57 binary.op()
58 )));
59 }
60
61 if source.is_some() && linear.is_none() {
62 return Err(invalid_program_state("malformed PROGRAM state: SOURCE requires LINEAR stage"));
63 }
64 if binary.is_some() && source.is_none() {
65 return Err(invalid_program_state("malformed PROGRAM state: BINARY requires SOURCE stage"));
66 }
67
68 Ok(())
69}
70
71fn preserve_program_context(new_program: Arc<UOp>, old_program: &Arc<UOp>) -> Arc<UOp> {
72 let mut out = new_program.rtag(old_program.tag().clone());
73 if let Some(meta) = old_program.metadata_raw() {
74 out = out.with_metadata_raw(meta);
75 }
76 out
77}
78
79fn rebuild_program(
80 base_program: &Arc<UOp>,
81 linear: Option<Arc<UOp>>,
82 source: Option<Arc<UOp>>,
83 binary: Option<Arc<UOp>>,
84) -> Result<Arc<UOp>> {
85 let (sink, device, _, _, _) = unpack_program(base_program)?;
86 let rebuilt = UOp::program(sink, device, linear, source, binary);
87 Ok(preserve_program_context(rebuilt, base_program))
88}
89
90pub fn program_from_sink(sink: Arc<UOp>, device: DeviceSpec) -> Arc<UOp> {
92 let sink = if matches!(sink.op(), Op::Sink { .. }) { sink } else { UOp::sink(vec![sink]) };
93 UOp::program(sink, UOp::device(device), None, None, None)
94}
95
96pub fn do_linearize(program: &Arc<UOp>) -> Result<Arc<UOp>> {
98 validate_program_shape(program)?;
99 let (sink, _device, linear, source, binary) = unpack_program(program)?;
100 if linear.is_some() {
101 return Ok(program.clone());
102 }
103
104 let linear_ops = svod_schedule::linearize_with_cfg(sink);
105 let linear_clean = line_rewrite_cleanups(linear_ops);
106 let linear_uop = UOp::linear(linear_clean.into());
107 rebuild_program(program, Some(linear_uop), source, binary)
108}
109
110pub fn do_render(program: &Arc<UOp>, renderer: &dyn Renderer, name: Option<&str>) -> Result<(Arc<UOp>, ProgramSpec)> {
112 let linearized = do_linearize(program)?;
113 let (_sink, _device, linear, source, binary) = unpack_program(&linearized)?;
114
115 if source.is_some() || binary.is_some() {
116 return Err(invalid_program_state(format!(
117 "do_render expects PROGRAM stage with LINEAR only (source=None,binary=None), got source_present={}, binary_present={}",
118 source.is_some(),
119 binary.is_some()
120 )));
121 }
122
123 let linear_uop = linear.clone().ok_or_else(|| invalid_program_state("PROGRAM has no LINEAR stage"))?;
124
125 let spec = renderer.render(&linear_uop, name)?;
126 let source_uop = UOp::source(spec.src.clone());
127 let mut rendered = rebuild_program(&linearized, linear, Some(source_uop), None)?;
128 rendered = rendered.with_metadata(spec.clone());
129 Ok((rendered, spec))
130}
131
132pub fn do_compile(program: &Arc<UOp>, compiler: &dyn Compiler) -> Result<(Arc<UOp>, CompiledSpec)> {
134 validate_program_shape(program)?;
135 let (sink, _device, linear, source, binary) = unpack_program(program)?;
136
137 if let Some(binary_uop) = binary {
138 let Op::ProgramBinary { bytes } = binary_uop.op() else {
139 return Err(invalid_program_state("PROGRAM binary stage is not a ProgramBinary UOp"));
140 };
141
142 let spec = ProgramSpec::from_uop(program)?;
143
144 let mut compiled = CompiledSpec::from_bytes(spec.name.clone(), bytes.clone(), sink);
145 if !spec.src.is_empty() {
146 compiled.src = Some(spec.src.clone());
147 }
148 compiled.var_names = spec.var_names.clone();
149 compiled.global_size = spec.global_size.clone();
150 compiled.local_size = spec.local_size.clone();
151 compiled.buf_count = spec.buf_count;
152 return Ok((program.clone(), compiled));
153 }
154
155 if source.is_none() {
156 return Err(invalid_program_state("PROGRAM has no SOURCE stage"));
157 }
158
159 let spec = ProgramSpec::from_uop(program)?;
160 if spec.src.is_empty() {
161 return Err(invalid_program_state("PROGRAM has empty SOURCE stage"));
162 }
163
164 let compiled = compiler.compile(&spec)?;
165
166 let binary_uop = UOp::binary(compiled.bytes.clone());
167 let mut compiled_program = rebuild_program(program, linear, source, Some(binary_uop))?;
168 compiled_program = compiled_program.with_metadata(spec);
169 Ok((compiled_program, compiled))
170}
171
172pub fn get_program(
174 input: &Arc<UOp>,
175 renderer: &dyn Renderer,
176 compiler: &dyn Compiler,
177 name: Option<&str>,
178 target: ProgramTarget,
179) -> Result<Arc<UOp>> {
180 let mut program = match input.op() {
181 Op::Program { .. } => {
182 validate_program_shape(input)?;
183 input.clone()
184 }
185 other => return Err(invalid_program_state(format!("expected PROGRAM input, got {other:?}"))),
186 };
187
188 if matches!(target, ProgramTarget::Linear | ProgramTarget::Source | ProgramTarget::Binary) {
189 let (_, _, linear, _, _) = unpack_program(&program)?;
190 if linear.is_none() {
191 program = do_linearize(&program)?;
192 }
193 }
194
195 if matches!(target, ProgramTarget::Source | ProgramTarget::Binary) {
196 let (_, _, _, source, _) = unpack_program(&program)?;
197 if source.is_none() {
198 let (rendered, _) = do_render(&program, renderer, name)?;
199 program = rendered;
200 }
201 }
202
203 if matches!(target, ProgramTarget::Binary) {
204 let (_, _, _, _, binary) = unpack_program(&program)?;
205 if binary.is_none() {
206 let (compiled, _) = do_compile(&program, compiler)?;
207 program = compiled;
208 }
209 }
210
211 validate_program_shape(&program)?;
212 Ok(program)
213}