Skip to main content

svod_codegen/c/
mod.rs

1//! C source code generation backend.
2//!
3//! Generates C source code from linearized UOp IR, suitable for compilation
4//! with `clang -shared -O2` and loading via `dlopen`.
5//!
6//! # Kernel Signature
7//!
8//! Emits a single function with typed `restrict` pointer params and const variable params:
9//!
10//! ```c
11//! void kernel(float* restrict data0, const int N) { /* body */ }
12//! ```
13
14mod amx;
15pub mod ops;
16pub mod types;
17
18use std::collections::{HashMap, HashSet};
19use std::sync::Arc;
20
21use svod_ir::pattern::TypedPatternMatcher;
22use svod_ir::{Op, prelude::*};
23
24use crate::common::{is_output_buffer, validate_custom_template_strict};
25use crate::{BufferArg, Error, RenderedKernel, Result};
26
27use self::ops::{CContext, count_references, render_uop};
28use self::types::{c_const, c_dtype, c_reduce_identity, c_vconst, collect_vector_typedefs};
29
/// C source code renderer for CPU execution via clang.
///
/// The renderer is a stateless unit type: all per-kernel bookkeeping is created
/// locally inside each `render` call, so a single instance can be reused or
/// copied freely.
#[derive(Debug, Clone, Copy, Default)]
pub struct CRenderer;

impl CRenderer {
    /// Creates a C renderer. Equivalent to [`CRenderer::default`].
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}
44
45impl crate::Renderer for CRenderer {
46    fn render(&self, uop: &Arc<UOp>, name: Option<&str>) -> Result<RenderedKernel> {
47        let kernel_name = name.unwrap_or("kernel");
48
49        let nodes: Vec<Arc<UOp>> = match uop.op() {
50            Op::Linear { ops } => ops.iter().cloned().collect(),
51            other => {
52                return Err(Error::InvalidGraph { reason: format!("C renderer expects LINEAR input, got {other:?}") });
53            }
54        };
55
56        for (i, node) in nodes.iter().enumerate() {
57            tracing::debug!(position = i, op = node.op().as_ref(), id = node.id, "c linearized node");
58            match node.op() {
59                Op::Custom { deps, code } | Op::CustomI { deps, code } => {
60                    validate_custom_template_strict(code, deps.len())?;
61                }
62                _ => {}
63            }
64        }
65
66        // Collect buffers and variables from linearized stream
67        let mut buffers: Vec<Arc<UOp>> = Vec::new();
68        let mut variables: Vec<Arc<UOp>> = Vec::new();
69
70        for node in &nodes {
71            match node.op() {
72                Op::Param { device: None, .. } => buffers.push(node.clone()),
73                Op::DefineVar { .. } => variables.push(node.clone()),
74                _ => {}
75            }
76        }
77
78        buffers.sort_by_key(|b| if let Op::Param { slot, device: None, .. } = b.op() { *slot } else { usize::MAX });
79
80        // Build buffer args metadata
81        let mut buffer_args: Vec<BufferArg> = Vec::new();
82        for (i, buf) in buffers.iter().enumerate() {
83            if let Op::Param { slot, device: None, .. } = buf.op() {
84                let is_output = is_output_buffer(buf, &nodes);
85                buffer_args.push(BufferArg { index: *slot, name: format!("data{i}"), dtype: buf.dtype(), is_output });
86            }
87        }
88
89        // Build var_names
90        let mut var_names: Vec<String> = Vec::new();
91        for var in &variables {
92            if let Op::DefineVar { name, .. } = var.op() {
93                var_names.push(name.clone());
94            }
95        }
96        // Count references for SSA inlining decisions
97        let ref_counts = count_references(&nodes);
98        let scope_escaping = find_scope_escaping_vars(&nodes, &ref_counts);
99        let mut ctx = CContext::new(ref_counts, scope_escaping);
100
101        // === Build C source ===
102        let mut code_lines: Vec<String> = Vec::new();
103
104        // Includes
105        code_lines.push("#include <stdbool.h>".to_string());
106        code_lines.push("".to_string());
107
108        // Vector typedefs
109        let typedefs = collect_vector_typedefs(&nodes);
110        for td in &typedefs {
111            code_lines.push(td.clone());
112        }
113        if !typedefs.is_empty() {
114            code_lines.push("".to_string());
115        }
116
117        // WMMA (AMX) defines and static functions
118        let wmma_defines = amx::collect_wmma_defines(&nodes);
119        for def in &wmma_defines {
120            code_lines.push(def.clone());
121        }
122        if !wmma_defines.is_empty() {
123            code_lines.push("".to_string());
124        }
125
126        // Build typed function params
127        let mut params: Vec<String> = Vec::new();
128
129        // Buffer parameters
130        for (i, buf) in buffers.iter().enumerate() {
131            let buf_dtype = buf.dtype();
132            let elem_type = match &buf_dtype {
133                DType::Ptr { base, .. } => c_dtype(base),
134                _ => c_dtype(&buf_dtype),
135            };
136            let name = format!("data{i}");
137            params.push(format!("{elem_type}* restrict {name}"));
138            ctx.register(buf.id, name);
139        }
140
141        // Variable parameters
142        for var in &variables {
143            if let Op::DefineVar { name, .. } = var.op() {
144                let var_dtype = &var.dtype();
145                let c_type = c_dtype(var_dtype);
146                params.push(format!("const {c_type} {name}"));
147                ctx.register(var.id, name.clone());
148            }
149        }
150
151        // Function signature
152        code_lines.push(format!("void {kernel_name}({}) {{", params.join(", ")));
153
154        // Local memory allocations (stack arrays on CPU)
155        for node in &nodes {
156            if let Op::DefineLocal(id) = node.op() {
157                let (base, size) = match node.dtype() {
158                    DType::Ptr { base, size, .. } => (c_dtype(&base), size.unwrap_or(1)),
159                    other => (c_dtype(&other), 1),
160                };
161                let name = format!("local{id}");
162                code_lines.push(format!("  {base} {name}[{size}];"));
163                ctx.register(node.id, name);
164            }
165        }
166
167        code_lines.push("".to_string());
168
169        // Reduction accumulator declarations (need to be in outer scope)
170        for node in &nodes {
171            if let Op::Reduce { reduce_op, ranges, .. } = node.op() {
172                if ranges.is_empty() {
173                    continue;
174                }
175                let dtype = &node.dtype();
176                let c_type = c_dtype(dtype);
177                let identity = c_reduce_identity(*reduce_op, dtype);
178                let acc_name = format!("acc{}", node.id);
179                code_lines.push(format!("  {c_type} {acc_name} = {identity};"));
180                // Pre-register so the ops.rs render_uop finds it
181                ctx.register(node.id, acc_name);
182            }
183        }
184
185        // Register constants
186        for node in &nodes {
187            match node.op() {
188                Op::Const(cv) => {
189                    let val = c_const(&cv.0, &node.dtype());
190                    ctx.register(node.id, val);
191                }
192                Op::VConst { values } => {
193                    let val = c_vconst(values, &node.dtype());
194                    ctx.register(node.id, val);
195                }
196                _ => {}
197            }
198        }
199
200        // Pre-register range variable names
201        for node in &nodes {
202            if let Op::Range { axis_id, .. } = node.op() {
203                let name = format!("ridx{}", axis_id.value());
204                ctx.register(node.id, name);
205            }
206        }
207
208        // Render all instructions
209        // Skip NOOP and GROUP — they are structural no-ops (Tinygrad cstyle.py:175)
210        let mut kernel_body: Vec<String> = Vec::new();
211        for node in &nodes {
212            if matches!(node.op(), Op::Noop | Op::Group { .. }) {
213                // Register with empty string so downstream UNROLL/CONTRACT can alias them.
214                // Matches LLVM backend behavior — these are structural no-ops.
215                ctx.register(node.id, String::new());
216                continue;
217            }
218            render_uop(node, &mut ctx, &mut kernel_body);
219            if let Some(err) = ctx.take_error() {
220                return Err(err);
221            }
222        }
223
224        // Emit hoisted declarations for scope-escaping variables (before kernel body)
225        if !ctx.hoisted_declarations.is_empty() {
226            code_lines.append(&mut ctx.hoisted_declarations);
227        }
228        code_lines.extend(kernel_body);
229        code_lines.push("}".to_string());
230        code_lines.push("".to_string());
231
232        let code = code_lines.join("\n");
233
234        tracing::debug!(generated_c = code, "c codegen: final generated code");
235
236        let mut result = RenderedKernel::new(code, kernel_name.to_string());
237        result.buffer_args = buffer_args;
238        result.var_names = var_names;
239
240        Ok(result)
241    }
242
243    fn backend_name(&self) -> &str {
244        "clang"
245    }
246
247    fn decompositor(&self) -> Option<TypedPatternMatcher<()>> {
248        // C uses __builtin_ math functions (sqrt, exp, sin, etc.) — no decomposition needed.
249        // Threefry is handled by XOR in render.
250        None
251    }
252}
253
254/// Find variables that escape their declaration scope.
255///
256/// Walks the linearized instruction list tracking scope depth. A variable "escapes"
257/// if it's defined at a deeper scope than where it's used. Returns the set of UOp IDs
258/// that need function-scope declarations to avoid "use of undeclared identifier" errors.
259///
260/// This handles the case where pm_decomp creates sibling ENDs that share sub-DAG nodes.
261/// The linearizer places the shared node inside one loop, but another consumer is outside.
262fn find_scope_escaping_vars(nodes: &[Arc<UOp>], ref_counts: &HashMap<u64, usize>) -> HashSet<u64> {
263    let mut depth = 0usize;
264    let mut def_depth: HashMap<u64, usize> = HashMap::new();
265    let mut min_use_depth: HashMap<u64, usize> = HashMap::new();
266
267    for node in nodes {
268        // Track scope depth changes
269        match node.op() {
270            Op::Range { .. } | Op::If { .. } => {
271                // Definition of this node is at current depth (before entering)
272                if ref_counts.get(&node.id).copied().unwrap_or(0) > 1 {
273                    def_depth.entry(node.id).or_insert(depth);
274                }
275                // Record usages of sources at current depth
276                for src in node.op().sources() {
277                    min_use_depth.entry(src.id).and_modify(|d| *d = (*d).min(depth)).or_insert(depth);
278                }
279                depth += 1;
280                continue;
281            }
282            Op::End { .. } | Op::EndIf { .. } => {
283                depth = depth.saturating_sub(1);
284            }
285            _ => {}
286        }
287
288        // Record definition depth for multi-use values
289        if ref_counts.get(&node.id).copied().unwrap_or(0) > 1 {
290            def_depth.entry(node.id).or_insert(depth);
291        }
292
293        // Record minimum usage depth for all source operands
294        for src in node.op().sources() {
295            min_use_depth.entry(src.id).and_modify(|d| *d = (*d).min(depth)).or_insert(depth);
296        }
297    }
298
299    // Variables where any use is at a shallower depth than definition
300    def_depth
301        .into_iter()
302        .filter(|(id, def_d)| min_use_depth.get(id).copied().unwrap_or(*def_d) < *def_d)
303        .map(|(id, _)| id)
304        .collect()
305}
306
307/// Public render function for the C backend.
308pub fn render(uop: &Arc<UOp>, name: Option<&str>) -> Result<RenderedKernel> {
309    let renderer = CRenderer::new();
310    crate::Renderer::render(&renderer, uop, name)
311}