Skip to main content

svod_codegen/
common.rs

//! Common utilities shared between codegen backends.
2
3use std::sync::Arc;
4
5use svod_ir::{Op, UOp};
6
7use crate::{Error, Result};
8
9/// Check whether a buffer (PARAM/DefineGlobal) is used as a STORE target in the graph.
10pub fn is_output_buffer(def_global: &Arc<UOp>, nodes: &[Arc<UOp>]) -> bool {
11    let buffer_id = def_global.id;
12
13    for node in nodes {
14        if let Some(buffer) = node.store_buffer() {
15            if buffer.id == buffer_id {
16                return true;
17            }
18            if let Op::Index { buffer: idx_buf, .. } = buffer.op()
19                && idx_buf.id == buffer_id
20            {
21                return true;
22            }
23        }
24    }
25    false
26}
27
28/// Collect buffer and variable parameters from a UOp graph.
29///
30/// Collects:
31/// - Buffers: PARAM, DEFINE_LOCAL operations
32/// - Variables: DEFINE_VAR operations (passed as i64 kernel params)
33///
34/// Returns (buffers, variables) sorted for deterministic function signatures.
35pub fn collect_buffers_and_vars(root: &Arc<UOp>) -> (Vec<Arc<UOp>>, Vec<Arc<UOp>>) {
36    let nodes = root.toposort();
37
38    // Collect buffers
39    let mut buffers = Vec::new();
40    for node in &nodes {
41        match node.op() {
42            Op::Buffer { .. } | Op::Param { device: None, .. } | Op::DefineLocal(_) => {
43                buffers.push(node.clone());
44            }
45            _ => {}
46        }
47    }
48
49    // Sort buffers by internal ID (matches split_kernel.rs ordering)
50    buffers.sort_by_key(|b| match b.op() {
51        Op::Param { slot, device: None, .. } => *slot as u64,
52        Op::DefineLocal(id) => (*id as u64) + (1u64 << 32),
53        Op::Buffer { .. } => b.id + (1u64 << 48),
54        _ => b.id,
55    });
56
57    // Collect DefineVar nodes - these become i64 kernel parameters
58    let mut variables = Vec::new();
59    for node in &nodes {
60        if matches!(node.op(), Op::DefineVar { .. }) {
61            variables.push(node.clone());
62        }
63    }
64
65    // Sort variables by name for deterministic function signatures
66    variables.sort_by_key(|v| if let Op::DefineVar { name, .. } = v.op() { name.clone() } else { String::new() });
67
68    (buffers, variables)
69}
70
71pub fn validate_custom_template_strict(template: &str, arg_count: usize) -> Result<()> {
72    let mut chars = template.chars().peekable();
73    let mut auto_idx = 0usize;
74    let mut saw_auto = false;
75    let mut saw_manual = false;
76
77    while let Some(ch) = chars.next() {
78        if ch == '{' {
79            if matches!(chars.peek(), Some('{')) {
80                chars.next();
81                continue;
82            }
83
84            let mut token = String::new();
85            let mut found_close = false;
86            for next in chars.by_ref() {
87                if next == '}' {
88                    found_close = true;
89                    break;
90                }
91                token.push(next);
92            }
93
94            if !found_close {
95                return Err(Error::InvalidGraph {
96                    reason: format!("custom template has unmatched '{{': {template:?}"),
97                });
98            }
99
100            let idx = if token.is_empty() {
101                saw_auto = true;
102                let i = auto_idx;
103                auto_idx += 1;
104                i
105            } else {
106                saw_manual = true;
107                token.parse::<usize>().map_err(|_| Error::InvalidGraph {
108                    reason: format!(
109                        "custom template placeholder must be empty or numeric, got {{{token}}} in {template:?}"
110                    ),
111                })?
112            };
113
114            if saw_auto && saw_manual {
115                return Err(Error::InvalidGraph {
116                    reason: format!("custom template mixes automatic {{}} and manual {{N}} placeholders: {template:?}"),
117                });
118            }
119
120            if idx >= arg_count {
121                return Err(Error::InvalidGraph {
122                    reason: format!(
123                        "custom template placeholder index {idx} out of bounds (args={arg_count}) in {template:?}"
124                    ),
125                });
126            }
127        } else if ch == '}' {
128            if matches!(chars.peek(), Some('}')) {
129                chars.next();
130            } else {
131                return Err(Error::InvalidGraph {
132                    reason: format!("custom template has unmatched '}}': {template:?}"),
133                });
134            }
135        }
136    }
137
138    Ok(())
139}
140
141pub fn format_custom_template_strict(template: &str, args: &[String]) -> Result<String> {
142    validate_custom_template_strict(template, args.len())?;
143
144    let mut out = String::new();
145    let mut chars = template.chars().peekable();
146    let mut auto_idx = 0usize;
147
148    while let Some(ch) = chars.next() {
149        if ch == '{' {
150            if matches!(chars.peek(), Some('{')) {
151                chars.next();
152                out.push('{');
153                continue;
154            }
155
156            let mut token = String::new();
157            for next in chars.by_ref() {
158                if next == '}' {
159                    break;
160                }
161                token.push(next);
162            }
163
164            let idx = if token.is_empty() {
165                let i = auto_idx;
166                auto_idx += 1;
167                i
168            } else {
169                token.parse::<usize>().expect("placeholder token validated")
170            };
171
172            out.push_str(&args[idx]);
173        } else if ch == '}' {
174            chars.next();
175            out.push('}');
176        } else {
177            out.push(ch);
178        }
179    }
180
181    Ok(out)
182}