Skip to main content

svod_codegen/
program_pipeline.rs

1use std::sync::Arc;
2
3use svod_device::device::{CompiledSpec, Compiler, ProgramSpec, Renderer};
4use svod_device::{Error, Result};
5use svod_dtype::DeviceSpec;
6use svod_ir::{Op, UOp};
7use svod_schedule::linearize::line_rewrite_cleanups;
8
9type ProgramParts = (Arc<UOp>, Arc<UOp>, Option<Arc<UOp>>, Option<Arc<UOp>>, Option<Arc<UOp>>);
10
/// Stage that `get_program` should advance a PROGRAM to.
///
/// Targets are cumulative: reaching `Source` also produces `Linear`, and reaching
/// `Binary` also produces `Source` and `Linear`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ProgramTarget {
    /// Stop once the LINEAR stage is present.
    Linear,
    /// Advance through rendering until the SOURCE stage is present.
    Source,
    /// Advance through compilation until the BINARY stage is present.
    Binary,
}
17
18fn invalid_program_state(details: impl Into<String>) -> Error {
19    Error::Runtime { message: details.into() }
20}
21
22fn unpack_program(program: &Arc<UOp>) -> Result<ProgramParts> {
23    let Op::Program { sink, device, linear, source, binary } = program.op() else {
24        return Err(invalid_program_state(format!("expected PROGRAM op, got {:?}", program.op())));
25    };
26    Ok((sink.clone(), device.clone(), linear.clone(), source.clone(), binary.clone()))
27}
28
29fn validate_program_shape(program: &Arc<UOp>) -> Result<()> {
30    let (sink, device, linear, source, binary) = unpack_program(program)?;
31
32    if !matches!(sink.op(), Op::Sink { .. }) {
33        return Err(invalid_program_state(format!("PROGRAM sink must be SINK op, got {:?}", sink.op())));
34    }
35
36    if !matches!(device.op(), Op::Device(_)) {
37        return Err(invalid_program_state(format!("PROGRAM device must be DEVICE op, got {:?}", device.op())));
38    }
39
40    if let Some(linear) = &linear
41        && !matches!(linear.op(), Op::Linear { .. })
42    {
43        return Err(invalid_program_state(format!("PROGRAM linear stage must be LINEAR op, got {:?}", linear.op())));
44    }
45
46    if let Some(source) = &source
47        && !matches!(source.op(), Op::Source { .. })
48    {
49        return Err(invalid_program_state(format!("PROGRAM source stage must be SOURCE op, got {:?}", source.op())));
50    }
51
52    if let Some(binary) = &binary
53        && !matches!(binary.op(), Op::ProgramBinary { .. })
54    {
55        return Err(invalid_program_state(format!(
56            "PROGRAM binary stage must be ProgramBinary op, got {:?}",
57            binary.op()
58        )));
59    }
60
61    if source.is_some() && linear.is_none() {
62        return Err(invalid_program_state("malformed PROGRAM state: SOURCE requires LINEAR stage"));
63    }
64    if binary.is_some() && source.is_none() {
65        return Err(invalid_program_state("malformed PROGRAM state: BINARY requires SOURCE stage"));
66    }
67
68    Ok(())
69}
70
71fn preserve_program_context(new_program: Arc<UOp>, old_program: &Arc<UOp>) -> Arc<UOp> {
72    let mut out = new_program.rtag(old_program.tag().clone());
73    if let Some(meta) = old_program.metadata_raw() {
74        out = out.with_metadata_raw(meta);
75    }
76    out
77}
78
79fn rebuild_program(
80    base_program: &Arc<UOp>,
81    linear: Option<Arc<UOp>>,
82    source: Option<Arc<UOp>>,
83    binary: Option<Arc<UOp>>,
84) -> Result<Arc<UOp>> {
85    let (sink, device, _, _, _) = unpack_program(base_program)?;
86    let rebuilt = UOp::program(sink, device, linear, source, binary);
87    Ok(preserve_program_context(rebuilt, base_program))
88}
89
90/// Create initial PROGRAM(sink, device) state.
91pub fn program_from_sink(sink: Arc<UOp>, device: DeviceSpec) -> Arc<UOp> {
92    let sink = if matches!(sink.op(), Op::Sink { .. }) { sink } else { UOp::sink(vec![sink]) };
93    UOp::program(sink, UOp::device(device), None, None, None)
94}
95
96/// PROGRAM -> LINEAR stage.
97pub fn do_linearize(program: &Arc<UOp>) -> Result<Arc<UOp>> {
98    validate_program_shape(program)?;
99    let (sink, _device, linear, source, binary) = unpack_program(program)?;
100    if linear.is_some() {
101        return Ok(program.clone());
102    }
103
104    let linear_ops = svod_schedule::linearize_with_cfg(sink);
105    let linear_clean = line_rewrite_cleanups(linear_ops);
106    let linear_uop = UOp::linear(linear_clean.into());
107    rebuild_program(program, Some(linear_uop), source, binary)
108}
109
110/// PROGRAM(+LINEAR) -> SOURCE stage via Renderer.
111pub fn do_render(program: &Arc<UOp>, renderer: &dyn Renderer, name: Option<&str>) -> Result<(Arc<UOp>, ProgramSpec)> {
112    let linearized = do_linearize(program)?;
113    let (_sink, _device, linear, source, binary) = unpack_program(&linearized)?;
114
115    if source.is_some() || binary.is_some() {
116        return Err(invalid_program_state(format!(
117            "do_render expects PROGRAM stage with LINEAR only (source=None,binary=None), got source_present={}, binary_present={}",
118            source.is_some(),
119            binary.is_some()
120        )));
121    }
122
123    let linear_uop = linear.clone().ok_or_else(|| invalid_program_state("PROGRAM has no LINEAR stage"))?;
124
125    let spec = renderer.render(&linear_uop, name)?;
126    let source_uop = UOp::source(spec.src.clone());
127    let mut rendered = rebuild_program(&linearized, linear, Some(source_uop), None)?;
128    rendered = rendered.with_metadata(spec.clone());
129    Ok((rendered, spec))
130}
131
132/// PROGRAM(+SOURCE) -> BINARY stage via Compiler.
133pub fn do_compile(program: &Arc<UOp>, compiler: &dyn Compiler) -> Result<(Arc<UOp>, CompiledSpec)> {
134    validate_program_shape(program)?;
135    let (sink, _device, linear, source, binary) = unpack_program(program)?;
136
137    if let Some(binary_uop) = binary {
138        let Op::ProgramBinary { bytes } = binary_uop.op() else {
139            return Err(invalid_program_state("PROGRAM binary stage is not a ProgramBinary UOp"));
140        };
141
142        let spec = ProgramSpec::from_uop(program)?;
143
144        let mut compiled = CompiledSpec::from_bytes(spec.name.clone(), bytes.clone(), sink);
145        if !spec.src.is_empty() {
146            compiled.src = Some(spec.src.clone());
147        }
148        compiled.var_names = spec.var_names.clone();
149        compiled.global_size = spec.global_size.clone();
150        compiled.local_size = spec.local_size.clone();
151        compiled.buf_count = spec.buf_count;
152        return Ok((program.clone(), compiled));
153    }
154
155    if source.is_none() {
156        return Err(invalid_program_state("PROGRAM has no SOURCE stage"));
157    }
158
159    let spec = ProgramSpec::from_uop(program)?;
160    if spec.src.is_empty() {
161        return Err(invalid_program_state("PROGRAM has empty SOURCE stage"));
162    }
163
164    let compiled = compiler.compile(&spec)?;
165
166    let binary_uop = UOp::binary(compiled.bytes.clone());
167    let mut compiled_program = rebuild_program(program, linear, source, Some(binary_uop))?;
168    compiled_program = compiled_program.with_metadata(spec);
169    Ok((compiled_program, compiled))
170}
171
172/// Progressively advance SINK/PROGRAM input to a requested PROGRAM stage.
173pub fn get_program(
174    input: &Arc<UOp>,
175    renderer: &dyn Renderer,
176    compiler: &dyn Compiler,
177    name: Option<&str>,
178    target: ProgramTarget,
179) -> Result<Arc<UOp>> {
180    let mut program = match input.op() {
181        Op::Program { .. } => {
182            validate_program_shape(input)?;
183            input.clone()
184        }
185        other => return Err(invalid_program_state(format!("expected PROGRAM input, got {other:?}"))),
186    };
187
188    if matches!(target, ProgramTarget::Linear | ProgramTarget::Source | ProgramTarget::Binary) {
189        let (_, _, linear, _, _) = unpack_program(&program)?;
190        if linear.is_none() {
191            program = do_linearize(&program)?;
192        }
193    }
194
195    if matches!(target, ProgramTarget::Source | ProgramTarget::Binary) {
196        let (_, _, _, source, _) = unpack_program(&program)?;
197        if source.is_none() {
198            let (rendered, _) = do_render(&program, renderer, name)?;
199            program = rendered;
200        }
201    }
202
203    if matches!(target, ProgramTarget::Binary) {
204        let (_, _, _, _, binary) = unpack_program(&program)?;
205        if binary.is_none() {
206            let (compiled, _) = do_compile(&program, compiler)?;
207            program = compiled;
208        }
209    }
210
211    validate_program_shape(&program)?;
212    Ok(program)
213}