Skip to main content

svod_codegen/llvm/text/
mod.rs

//! Text-based LLVM IR code generation (main entry point).
//!
//! This module generates LLVM IR as plain strings using `format!` macros,
//! following Tinygrad's approach in `renderer/llvmir.py`.
//!
//! # Kernel Signature
//!
//! Generates a single function with direct typed parameters: one
//! `ptr noalias align 32` parameter per buffer (ordered by parameter slot),
//! followed by one scalar parameter per variable:
//! ```llvm
//! define void @kernel(ptr noalias align 32 %buf0, ..., i32 %N) #0 { ... }
//! ```
13
14use std::sync::Arc;
15
16use svod_ir::pattern::TypedPatternMatcher;
17use svod_ir::{Op, prelude::*};
18
19use crate::common::is_output_buffer;
20use crate::llvm::common::{RenderContext, ldt};
21use crate::llvm::cpu::{reduce_identity, render_uop};
22use crate::{BufferArg, Error, RenderedKernel, Renderer, Result};
23
/// Text-based LLVM IR renderer.
///
/// Generates LLVM IR as strings, suitable for compilation via external clang.
/// Produces a single function with direct typed parameters.
///
/// The renderer carries no state of its own, so it is `Copy` and can be
/// created freely wherever a [`Renderer`] is needed.
#[derive(Debug, Clone, Copy)]
pub struct LlvmTextRenderer;
29
30impl LlvmTextRenderer {
31    pub fn new() -> Self {
32        Self
33    }
34}
35
36impl Default for LlvmTextRenderer {
37    fn default() -> Self {
38        Self::new()
39    }
40}
41
42impl Renderer for LlvmTextRenderer {
43    fn render(&self, uop: &Arc<UOp>, name: Option<&str>) -> Result<RenderedKernel> {
44        let kernel_name = name.unwrap_or("kernel");
45
46        let nodes: Vec<Arc<UOp>> = match uop.op() {
47            Op::Linear { ops } => ops.iter().cloned().collect(),
48            other => {
49                return Err(Error::InvalidGraph {
50                    reason: format!("LLVM text renderer expects LINEAR input, got {other:?}"),
51                });
52            }
53        };
54
55        for (i, node) in nodes.iter().enumerate() {
56            tracing::debug!(position = i, op = node.op().as_ref(), id = node.id, "linearized node");
57            if matches!(node.op(), Op::Custom { .. } | Op::CustomI { .. }) {
58                return Err(Error::InvalidGraph {
59                    reason: format!(
60                        "LLVM backend does not support CUSTOM/CUSTOMI templates (op id {}); use C backend for custom templates",
61                        node.id
62                    ),
63                });
64            }
65        }
66
67        let mut ctx = RenderContext::new();
68        let mut kernel: Vec<String> = Vec::new();
69        let mut buffer_args: Vec<BufferArg> = Vec::new();
70        let mut var_names: Vec<String> = Vec::new();
71
72        let mut buffers: Vec<Arc<UOp>> = Vec::new();
73        let mut variables: Vec<Arc<UOp>> = Vec::new();
74
75        for node in &nodes {
76            match node.op() {
77                Op::Param { device: None, .. } => {
78                    buffers.push(node.clone());
79                }
80                Op::DefineVar { .. } => {
81                    variables.push(node.clone());
82                }
83                _ => {}
84            }
85        }
86
87        buffers.sort_by_key(|b| if let Op::Param { slot, device: None, .. } = b.op() { *slot } else { usize::MAX });
88
89        for (i, buf) in buffers.iter().enumerate() {
90            if let Op::Param { slot, device: None, .. } = buf.op() {
91                let is_output = is_output_buffer(buf, &nodes);
92                buffer_args.push(BufferArg { index: *slot, name: format!("data{i}"), dtype: buf.dtype(), is_output });
93            }
94        }
95
96        for var in &variables {
97            if let Op::DefineVar { name, .. } = var.op() {
98                var_names.push(name.clone());
99            }
100        }
101        // -- Build function parameters --
102        let mut inner_params: Vec<String> = Vec::new();
103
104        // Buffer pointer parameters
105        for (i, buf) in buffers.iter().enumerate() {
106            inner_params.push(format!("ptr noalias align 32 %buf{i}"));
107            ctx.register(buf.id, format!("%buf{i}"));
108        }
109
110        // Variable parameters
111        for var in &variables {
112            let var_base_name =
113                if let Op::DefineVar { name, .. } = var.op() { name.clone() } else { "var".to_string() };
114            let var_dtype = var.dtype();
115            let var_dtype_str = ldt(&var_dtype);
116            inner_params.push(format!("{var_dtype_str} %{var_base_name}"));
117            ctx.register(var.id, format!("%{var_base_name}"));
118        }
119
120        // -- Build function body --
121        kernel.push("  ; Reduction accumulators".to_string());
122        for node in &nodes {
123            if let Op::Reduce { reduce_op, .. } = node.op() {
124                let dtype = ldt(&node.dtype());
125                let identity = reduce_identity(*reduce_op, &node.dtype());
126                let acc_name = format!("%reduce_{}", node.id);
127                kernel.push(format!("  {acc_name} = alloca {dtype}"));
128                kernel.push(format!("  store {dtype} {identity}, ptr {acc_name}"));
129                ctx.register(node.id, acc_name);
130            }
131        }
132
133        // WMMA scratch buffers — one alloca + ptrtoint per (A, B, C) operand.
134        // Allocas placed in the entry block so LLVM's mem2reg can promote them
135        // to vector registers across loop iterations. Without this, the WMMA
136        // accumulator is materialized to memory every K iteration.
137        let wmma_count = nodes.iter().filter(|n| matches!(n.op(), Op::Wmma { .. })).count();
138        if wmma_count > 0 {
139            kernel.push("  ; WMMA AMX scratch buffers".to_string());
140            for node in &nodes {
141                if let Op::Wmma { a, b, c, .. } = node.op() {
142                    for (i, src) in [a, b, c].iter().enumerate() {
143                        let dtype = ldt(&src.dtype());
144                        let base = format!("%wmma_{}_amx{}", node.id, i);
145                        let ptr_name = format!("%wmma_{}_ptr_amx{}", node.id, i);
146                        let align = src.dtype().bytes();
147                        kernel.push(format!("  {base} = alloca {dtype}, align {align}"));
148                        kernel.push(format!("  {ptr_name} = ptrtoint ptr {base} to i64"));
149                    }
150                }
151            }
152        }
153        kernel.push("".to_string());
154
155        for node in &nodes {
156            match node.op() {
157                Op::Const(cv) => {
158                    if node.dtype().vcount() > 1 {
159                        // Vector-typed CONST → splat via insertelement +
160                        // shufflevector.
161                        //
162                        // Invariant after this pre-pass: any UOp with a vector
163                        // dtype either has a true vector value (load, ALU
164                        // result, vectorize) or — for vector CONSTs — gets a
165                        // named splat value emitted in the entry block.
166                        let scalar_dtype = node.dtype().scalar_dtype();
167                        let scalar_lit = crate::llvm::common::lconst(&cv.0, &scalar_dtype);
168                        let scalar_ty = ldt(&scalar_dtype);
169                        let count = node.dtype().vcount();
170                        let dst = ctx.name(node);
171                        kernel.push(format!(
172                            "  {dst}_splat0 = insertelement <1 x {scalar_ty}> poison, {scalar_ty} {scalar_lit}, i32 0"
173                        ));
174                        kernel.push(format!(
175                            "  {dst} = shufflevector <1 x {scalar_ty}> {dst}_splat0, \
176                             <1 x {scalar_ty}> poison, <{count} x i32> zeroinitializer"
177                        ));
178                    } else {
179                        let val = crate::llvm::common::lconst(&cv.0, &node.dtype());
180                        ctx.register(node.id, val);
181                    }
182                }
183                Op::VConst { values } => {
184                    // Per-lane vector CONST → VECTORIZE chain of scalar
185                    // CONSTs, emitted as a sequence of insertelements.
186                    let scalar_dtype = node.dtype().scalar_dtype();
187                    let scalar_ty = ldt(&scalar_dtype);
188                    let vec_ty = ldt(&node.dtype());
189                    let dst = ctx.name(node);
190                    let mut prev = "poison".to_string();
191                    for (i, cv) in values.iter().enumerate() {
192                        let scalar_lit = crate::llvm::common::lconst(cv, &scalar_dtype);
193                        let next = if i + 1 == values.len() { dst.clone() } else { format!("{dst}_e{i}") };
194                        kernel.push(format!(
195                            "  {next} = insertelement {vec_ty} {prev}, {scalar_ty} {scalar_lit}, i32 {i}"
196                        ));
197                        prev = next;
198                    }
199                }
200                _ => {}
201            }
202        }
203
204        for node in &nodes {
205            if let Op::Range { axis_id, .. } = node.op() {
206                let name = format!("%r{}", axis_id.value());
207                ctx.register(node.id, name);
208            }
209        }
210
211        for node in &nodes {
212            if matches!(node.op(), Op::Noop | Op::Group { .. }) {
213                ctx.register(node.id, String::new());
214                continue;
215            }
216            render_uop(node, &mut ctx, &mut kernel);
217            if let Some(err) = ctx.take_error() {
218                return Err(err);
219            }
220        }
221
222        kernel.push("  ret void".to_string());
223
224        let ir = format!(
225            r#"; ModuleID = '{kernel_name}'
226source_filename = "{kernel_name}"
227
228{intrinsics}
229
230define void @{kernel_name}({inner_params}) #0 {{
231entry:
232{inner_body}
233}}
234
235attributes #0 = {{ nounwind "no-builtins" "no-trapping-math"="true" }}
236"#,
237            intrinsics = generate_intrinsic_declarations(&kernel),
238            inner_params = inner_params.join(", "),
239            inner_body = kernel.join("\n"),
240        );
241
242        tracing::debug!(generated_code = ir, "llvm codegen: final generated code");
243
244        let mut result = RenderedKernel::new(ir, kernel_name.to_string());
245        result.buffer_args = buffer_args;
246        result.var_names = var_names;
247
248        Ok(result)
249    }
250
251    fn backend_name(&self) -> &str {
252        "llvm-text"
253    }
254
255    fn decompositor(&self) -> Option<TypedPatternMatcher<()>> {
256        None
257    }
258}
259
/// Converts an LLVM IR type spelling into the suffix LLVM uses when mangling
/// overloaded intrinsic names (e.g. the `v4f32` in `@llvm.sqrt.v4f32`).
///
/// Supported spellings:
/// - floating point: `half` → `f16`, `bfloat` → `bf16`, `float` → `f32`,
///   `double` → `f64`, `fp128` → `f128`, `x86_fp80` → `f80`;
/// - opaque pointer: `ptr` → `p0`;
/// - fixed vectors: `<N x T>` → `vN` followed by the mangling of `T`;
/// - scalable vectors: `<vscale x N x T>` → `nxvN` followed by the mangling of `T`.
///
/// Integer types (`i1`, `i8`, `i32`, …) are already spelled the way LLVM
/// mangles them. They, and any spelling not listed above, are returned
/// unchanged.
fn mangle_type(llvm_type: &str) -> String {
    let scalar = match llvm_type {
        "half" => Some("f16"),
        "bfloat" => Some("bf16"),
        "float" => Some("f32"),
        "double" => Some("f64"),
        "fp128" => Some("f128"),
        "x86_fp80" => Some("f80"),
        "ptr" => Some("p0"),
        _ => None,
    };
    if let Some(mangled) = scalar {
        return mangled.to_string();
    }

    let inner = match llvm_type.strip_prefix('<').and_then(|t| t.strip_suffix('>')) {
        Some(inner) => inner,
        None => return llvm_type.to_string(),
    };
    let parts: Vec<&str> = inner.split(" x ").map(str::trim).collect();
    match parts.as_slice() {
        [count, elem] => format!("v{count}{}", mangle_type(elem)),
        ["vscale", count, elem] => format!("nxv{count}{}", mangle_type(elem)),
        _ => llvm_type.to_string(),
    }
}
283
284fn generate_intrinsic_declarations(kernel: &[String]) -> String {
285    let mut decls = Vec::new();
286    let kernel_str = kernel.join("\n");
287
288    for intrinsic in &[
289        "sqrt", "exp", "exp2", "log", "log2", "sin", "cos", "pow", "fabs", "floor", "ceil", "trunc", "round", "maxnum",
290        "minnum", "fmuladd", "erf",
291    ] {
292        for llvm_type in
293            &["float", "double", "half", "<2 x float>", "<4 x float>", "<8 x float>", "<2 x double>", "<4 x double>"]
294        {
295            let mangled = mangle_type(llvm_type);
296            let pattern = format!("@llvm.{intrinsic}.{mangled}");
297            if kernel_str.contains(&pattern) {
298                let decl = match *intrinsic {
299                    "fmuladd" => format!(
300                        "declare {llvm_type} @llvm.{intrinsic}.{mangled}({llvm_type}, {llvm_type}, {llvm_type})"
301                    ),
302                    "pow" | "maxnum" | "minnum" => {
303                        format!("declare {llvm_type} @llvm.{intrinsic}.{mangled}({llvm_type}, {llvm_type})")
304                    }
305                    _ => format!("declare {llvm_type} @llvm.{intrinsic}.{mangled}({llvm_type})"),
306                };
307                decls.push(decl);
308            }
309        }
310    }
311
312    for bits in &["i8", "i16", "i32", "i64"] {
313        let pattern = format!("@llvm.abs.{bits}");
314        if kernel_str.contains(&pattern) {
315            decls.push(format!("declare {bits} @llvm.abs.{bits}({bits}, i1)"));
316        }
317    }
318
319    decls.join("\n")
320}
321
322pub fn render(uop: &Arc<UOp>, name: Option<&str>) -> Result<RenderedKernel> {
323    let renderer = LlvmTextRenderer::new();
324    renderer.render(uop, name)
325}
326
327#[cfg(test)]
328#[path = "../../test/unit/llvm_text.rs"]
329mod tests;