Skip to main content

svod_codegen/llvm/cpu/
ops.rs

1//! CPU-specific LLVM IR operation rendering.
2//!
3//! Generates LLVM IR strings for individual UOp operations on CPU.
4//! Based on Tinygrad's PatternMatcher templates in `llvmir.py`.
5
6use std::sync::Arc;
7
8use svod_dtype::DType;
9use svod_ir::{BinaryOp, Op, ReduceOp, TernaryOp, UnaryOp, prelude::*};
10
11use crate::llvm::common::{RenderContext, lcast, ldt};
12
13/// Extract a scalar `ptr` from a vectorized `<N x ptr>` via `extractelement ... i32 0`.
14///
15/// When the devectorize pipeline doesn't fully eliminate vectorized PARAM pointers
16/// (see `no_vectorized_buf` / `no_vectorized_index` which only target DEFINE_LOCAL/DEFINE_REG),
17/// the GEP result can be `<N x ptr>`. All elements are identical (broadcast of the same buffer
18/// pointer), so extracting element 0 yields the correct scalar ptr for LLVM load/store.
19fn maybe_extract_scalar_ptr(
20    dst: &str,
21    idx: &str,
22    idx_type: &str,
23    dtype: &DType,
24    kernel: &mut Vec<String>,
25) -> (String, String) {
26    if matches!(dtype, DType::Ptr { vcount, .. } if *vcount > 1) {
27        let extract = format!("{dst}.ptr");
28        kernel.push(format!("  {extract} = extractelement {idx_type} {idx}, i32 0"));
29        (extract, "ptr".to_string())
30    } else {
31        (idx.to_string(), idx_type.to_string())
32    }
33}
34
35/// Render a UOp to LLVM IR string.
36///
37/// Returns None for meta-ops that don't produce instructions.
38pub fn render_uop(uop: &Arc<UOp>, ctx: &mut RenderContext, kernel: &mut Vec<String>) -> Option<()> {
39    let dst = ctx.name(uop);
40
41    match uop.op() {
42        Op::Const(_)
43        | Op::VConst { .. }
44        | Op::Param { device: None, .. }
45        | Op::DefineVar { .. }
46        | Op::Noop
47        | Op::Sink { .. }
48        | Op::Group { .. }
49        | Op::Buffer { .. }
50        | Op::Unique(_)
51        | Op::Device(_)
52        | Op::Call { .. }
53        | Op::Barrier { .. } => None,
54
55        Op::DefineLocal(_) | Op::DefineReg { .. } => {
56            // Emit alloca for local/register memory.
57            // Read base type and size from dtype (matching Tinygrad's x.dtype.base/x.dtype.size).
58            // After devectorize's no_vectorized_buf, dtype is the canonical source of truth.
59            let (base_dtype, alloc_size) = match uop.dtype() {
60                DType::Ptr { base, size, .. } => (base.as_ref().clone(), size.unwrap_or(1)),
61                other => (other, 1),
62            };
63            let base = ldt(&base_dtype);
64            // Tinygrad: DEFINE_LOCAL gets align 16 (for SSE vector loads), DEFINE_REG gets default.
65            let align = if matches!(uop.op(), Op::DefineLocal(_)) { ", align 16" } else { "" };
66            kernel.push(format!("  {dst} = alloca [{alloc_size} x {base}]{align}"));
67            Some(())
68        }
69
70        Op::Index { buffer, indices, .. } => {
71            let buf = ctx.get(buffer);
72            let buf_type = ldt(&buffer.dtype());
73
74            if indices.is_empty() {
75                kernel.push(format!("  {dst} = bitcast {buf_type} {buf} to {}", ldt(&uop.dtype())));
76            } else {
77                let (final_idx, final_idx_type) = if indices.len() == 1 {
78                    (ctx.get(&indices[0]).to_string(), ldt(&indices[0].dtype()))
79                } else {
80                    ctx.set_invalid_graph(format!(
81                        "LLVM renderer requires linearized INDEX (single-axis), found {} indices on uop {}",
82                        indices.len(),
83                        uop.id
84                    ));
85                    return None;
86                };
87
88                let elem_type = match uop.dtype() {
89                    svod_dtype::DType::Ptr { ref base, .. } => ldt(base),
90                    other => ldt(&other),
91                };
92
93                // Gate is NOT handled here — matching Tinygrad's approach where INDEX
94                // always emits a plain GEP. The gate is handled at LOAD level (branch+phi)
95                // and at STORE level (IF/ENDIF via line_rewrite_cleanups).
96                kernel.push(format!(
97                    "  {dst} = getelementptr inbounds {elem_type}, {buf_type} {buf}, {final_idx_type} {final_idx}"
98                ));
99            }
100            Some(())
101        }
102
103        Op::PointerIndex { ptr, offset } => {
104            let ptr_val = ctx.get(ptr);
105            let off_val = ctx.get(offset);
106            let elem_type = ldt(&uop.dtype());
107            let ptr_type = ldt(&ptr.dtype());
108            let off_type = ldt(&offset.dtype());
109
110            kernel.push(format!(
111                "  {dst} = getelementptr inbounds {elem_type}, {ptr_type} {ptr_val}, {off_type} {off_val}"
112            ));
113            Some(())
114        }
115
116        Op::Load { index, alt, .. } => {
117            let idx = ctx.get(index);
118            let dtype = ldt(&uop.dtype());
119            let idx_type = ldt(&index.dtype());
120
121            let (idx, idx_type) = maybe_extract_scalar_ptr(&dst, idx, &idx_type, &index.dtype(), kernel);
122
123            // Gated LOAD: emit branch+phi to avoid null deref.
124            // Matches Tinygrad's pattern (llvmir.py:123-129) which requires BOTH
125            // a gated INDEX and an alt value on the LOAD. If gate exists without
126            // alt, that's a pipeline bug (line_rewrite_cleanups should provide it).
127            // Unwrap one CAST layer to find the INDEX gate (matches Tinygrad's .or_casted("idx")).
128            // The pipeline CAN produce CAST(INDEX) — devectorize handles this shape explicitly.
129            let actual_index = match index.op() {
130                Op::Cast { src, .. } => src,
131                _ => index,
132            };
133            let gate_info = if let Op::Index { gate: Some(gate_uop), .. } = actual_index.op() {
134                let Some(alt_uop) = alt.as_ref() else {
135                    ctx.set_invalid_graph(format!(
136                        "gated LOAD on uop {} has no alt value; line_rewrite_cleanups must lift gated LOADs",
137                        uop.id
138                    ));
139                    return None;
140                };
141                Some((ctx.get(gate_uop).to_string(), ctx.get(alt_uop).to_string()))
142            } else {
143                None
144            };
145
146            if let Some((gate, alt_val)) = gate_info {
147                let label_base = &dst[1..]; // strip leading %
148                let entry_label = format!("{label_base}_entry");
149                let load_label = format!("{label_base}_load");
150                let exit_label = format!("{label_base}_exit");
151                let load_val = format!("{dst}_yes");
152
153                kernel.push(format!("  br label %{entry_label}"));
154                kernel.push(format!("{entry_label}:"));
155                kernel.push(format!("  br i1 {gate}, label %{load_label}, label %{exit_label}"));
156                kernel.push(format!("{load_label}:"));
157                kernel.push(format!("  {load_val} = load {dtype}, {idx_type} {idx}"));
158                kernel.push(format!("  br label %{exit_label}"));
159                kernel.push(format!("{exit_label}:"));
160                kernel.push(format!("  {dst} = phi {dtype} [{load_val}, %{load_label}], [{alt_val}, %{entry_label}]"));
161            } else {
162                kernel.push(format!("  {dst} = load {dtype}, {idx_type} {idx}"));
163            }
164            Some(())
165        }
166
167        Op::Store { index, value, .. } => {
168            let idx = ctx.get(index);
169            let val = ctx.get(value);
170            let val_type = ldt(&value.dtype());
171            let idx_type = ldt(&index.dtype());
172
173            let (idx, idx_type) = maybe_extract_scalar_ptr(&dst, idx, &idx_type, &index.dtype(), kernel);
174
175            kernel.push(format!("  store {val_type} {val}, {idx_type} {idx}"));
176            Some(())
177        }
178
179        Op::Binary(op, lhs, rhs) => {
180            let l = ctx.get(lhs);
181            let r = ctx.get(rhs);
182            let ltype = ldt(&lhs.dtype());
183            let rtype = ldt(&rhs.dtype());
184
185            // Debug: detect type mismatch (logged via tracing)
186            if ltype != rtype {
187                tracing::error!(
188                    uop_id = uop.id,
189                    uop_dtype = ?uop.dtype(),
190                    op = ?op,
191                    lhs_id = lhs.id,
192                    rhs_id = rhs.id,
193                    lhs_dtype = ?lhs.dtype(),
194                    rhs_dtype = ?rhs.dtype(),
195                    lhs_op = ?lhs.op().as_ref(),
196                    rhs_op = ?rhs.op().as_ref(),
197                    "Binary op type mismatch - lhs and rhs have different dtypes"
198                );
199            }
200
201            if matches!(op, BinaryOp::Max) {
202                render_binary_max(&dst, lhs, l, r, &ltype, kernel);
203            } else if matches!(op, BinaryOp::Pow) {
204                render_binary_pow(&dst, lhs, l, r, &ltype, kernel);
205            } else {
206                let instr = binary_instr(*op, &lhs.dtype());
207                kernel.push(format!("  {dst} = {instr} {ltype} {l}, {r}"));
208            }
209            Some(())
210        }
211
212        Op::Unary(op, src) => {
213            let s = ctx.get(src);
214            let stype = ldt(&src.dtype());
215
216            match op {
217                UnaryOp::Neg => {
218                    if src.dtype().is_float() {
219                        kernel.push(format!("  {dst} = fneg {stype} {s}"));
220                    } else {
221                        kernel.push(format!("  {dst} = sub {stype} 0, {s}"));
222                    }
223                }
224                UnaryOp::Not => {
225                    let all_ones = if src.dtype().is_bool() { "1".to_string() } else { "-1".to_string() };
226                    kernel.push(format!("  {dst} = xor {stype} {s}, {all_ones}"));
227                }
228                UnaryOp::Floor | UnaryOp::Ceil | UnaryOp::Trunc | UnaryOp::Round if !src.dtype().is_float() => {
229                    // Rounding is identity for integer types (defense-in-depth;
230                    // symbolic_simple folds these away upstream).
231                    kernel.push(format!("  {dst} = bitcast {stype} {s} to {stype}"));
232                }
233                UnaryOp::Sqrt
234                | UnaryOp::Exp
235                | UnaryOp::Exp2
236                | UnaryOp::Log
237                | UnaryOp::Log2
238                | UnaryOp::Sin
239                | UnaryOp::Cos
240                | UnaryOp::Floor
241                | UnaryOp::Ceil
242                | UnaryOp::Trunc
243                | UnaryOp::Round => {
244                    let intrinsic = unary_instr(*op, &src.dtype()).unwrap();
245                    render_intrinsic(&dst, intrinsic, &[(&stype, s)], &stype, kernel);
246                }
247                UnaryOp::Abs => {
248                    if src.dtype().is_float() {
249                        render_intrinsic(&dst, "fabs", &[(&stype, s)], &stype, kernel);
250                    } else {
251                        render_intrinsic(&dst, "abs", &[(&stype, s), ("i1", "1")], &stype, kernel);
252                    }
253                }
254                UnaryOp::Rsqrt => {
255                    let sqrt_dst = format!("{dst}.sqrt");
256                    render_intrinsic(&sqrt_dst, "sqrt", &[(&stype, s)], &stype, kernel);
257                    let one = splat_or_literal("1.0", &src.dtype(), kernel, &dst);
258                    kernel.push(format!("  {dst} = fdiv nsz arcp contract afn {stype} {one}, {sqrt_dst}"));
259                }
260                UnaryOp::Reciprocal => {
261                    let one = splat_or_literal("1.0", &src.dtype(), kernel, &dst);
262                    kernel.push(format!("  {dst} = fdiv nsz arcp contract afn {stype} {one}, {s}"));
263                }
264                UnaryOp::Tan => {
265                    let sin_dst = format!("{dst}.sin");
266                    let cos_dst = format!("{dst}.cos");
267                    render_intrinsic(&sin_dst, "sin", &[(&stype, s)], &stype, kernel);
268                    render_intrinsic(&cos_dst, "cos", &[(&stype, s)], &stype, kernel);
269                    kernel.push(format!("  {dst} = fdiv nsz arcp contract afn {stype} {sin_dst}, {cos_dst}"));
270                }
271                UnaryOp::Sign => {
272                    if src.dtype().is_float() {
273                        let gt_zero = format!("{dst}.gt");
274                        let lt_zero = format!("{dst}.lt");
275                        let gt_ext = format!("{dst}.gt_ext");
276                        let lt_ext = format!("{dst}.lt_ext");
277                        let zero = splat_or_literal("0.0", &src.dtype(), kernel, &dst);
278                        kernel.push(format!("  {gt_zero} = fcmp nsz arcp contract afn ogt {stype} {s}, {zero}"));
279                        kernel.push(format!("  {lt_zero} = fcmp nsz arcp contract afn olt {stype} {s}, {zero}"));
280                        kernel.push(format!("  {gt_ext} = uitofp i1 {gt_zero} to {stype}"));
281                        kernel.push(format!("  {lt_ext} = uitofp i1 {lt_zero} to {stype}"));
282                        kernel.push(format!("  {dst} = fsub nsz arcp contract afn {stype} {gt_ext}, {lt_ext}"));
283                    } else if src.dtype().is_signed() {
284                        let gt_zero = format!("{dst}.gt");
285                        let lt_zero = format!("{dst}.lt");
286                        let gt_ext = format!("{dst}.gt_ext");
287                        let lt_ext = format!("{dst}.lt_ext");
288                        let zero = splat_or_literal("0", &src.dtype(), kernel, &dst);
289                        kernel.push(format!("  {gt_zero} = icmp sgt {stype} {s}, {zero}"));
290                        kernel.push(format!("  {lt_zero} = icmp slt {stype} {s}, {zero}"));
291                        kernel.push(format!("  {gt_ext} = zext i1 {gt_zero} to {stype}"));
292                        kernel.push(format!("  {lt_ext} = zext i1 {lt_zero} to {stype}"));
293                        kernel.push(format!("  {dst} = sub {stype} {gt_ext}, {lt_ext}"));
294                    } else {
295                        // Unsigned: sign(x) = (x != 0) ? 1 : 0.
296                        let ne_zero = format!("{dst}.ne");
297                        let zero = splat_or_literal("0", &src.dtype(), kernel, &dst);
298                        kernel.push(format!("  {ne_zero} = icmp ne {stype} {s}, {zero}"));
299                        kernel.push(format!("  {dst} = zext i1 {ne_zero} to {stype}"));
300                    }
301                }
302                UnaryOp::Erf => {
303                    render_intrinsic(&dst, "erf", &[(&stype, s)], &stype, kernel);
304                }
305                UnaryOp::Square => {
306                    if src.dtype().is_float() {
307                        kernel.push(format!("  {dst} = fmul nsz arcp contract afn {stype} {s}, {s}"));
308                    } else {
309                        kernel.push(format!("  {dst} = mul {stype} {s}, {s}"));
310                    }
311                }
312            }
313            Some(())
314        }
315
316        Op::Ternary(TernaryOp::Where, cond, t, f) => {
317            let c = ctx.get(cond);
318            let tv = ctx.get(t);
319            let fv = ctx.get(f);
320            kernel.push(format!(
321                "  {dst} = select {} {c}, {} {tv}, {} {fv}",
322                ldt(&cond.dtype()),
323                ldt(&t.dtype()),
324                ldt(&f.dtype())
325            ));
326            Some(())
327        }
328
329        Op::Ternary(TernaryOp::MulAcc, a, b, c) => {
330            let av = ctx.get(a);
331            let bv = ctx.get(b);
332            let cv = ctx.get(c);
333            let dtype = ldt(&a.dtype());
334
335            if a.dtype().is_float() {
336                render_intrinsic(&dst, "fmuladd", &[(&dtype, av), (&dtype, bv), (&dtype, cv)], &dtype, kernel);
337            } else {
338                let mul_dst = format!("{dst}.mul");
339                kernel.push(format!("  {mul_dst} = mul {dtype} {av}, {bv}"));
340                kernel.push(format!("  {dst} = add {dtype} {mul_dst}, {cv}"));
341            }
342            Some(())
343        }
344
345        Op::Cast { src, dtype } => {
346            let s = ctx.get(src);
347
348            // INDEX always produces ptr in LLVM (via GEP), regardless of Svod dtype.
349            // When source is INDEX, treat source LLVM type as ptr for cast selection.
350            let is_index_src = matches!(src.op(), Op::Index { .. });
351            let src_llvm_type = if is_index_src { "ptr".to_string() } else { ldt(&src.dtype()) };
352            let dst_llvm_type = ldt(dtype);
353
354            // CAST(INDEX) to Ptr is a no-op - INDEX already produces ptr via GEP.
355            // This matches Tinygrad's approach (llvmir.py:189) where CAST to PtrDType
356            // is register aliasing: r[u] = r[u.src[0]]
357            if is_index_src && matches!(dtype, DType::Ptr { .. }) {
358                // Emit a bitcast as a named no-op to maintain SSA form
359                kernel.push(format!("  {dst} = bitcast ptr {s} to ptr"));
360                return Some(());
361            }
362
363            if dtype.is_bool() && !src.dtype().is_bool() {
364                // Cast to bool: compare != 0 (not trunc, which only takes the low bit).
365                // Matches Tinygrad llvmir.py:99-101.
366                let cmp = if src.dtype().is_float() { "fcmp nsz arcp contract afn une" } else { "icmp ne" };
367                kernel.push(format!("  {dst} = {cmp} {src_llvm_type} {s}, zeroinitializer"));
368            } else if src_llvm_type == dst_llvm_type {
369                kernel.push(format!("  {dst} = bitcast {src_llvm_type} {s} to {dst_llvm_type}"));
370            } else {
371                let cast_instr = lcast(&src.dtype(), dtype);
372                kernel.push(format!("  {dst} = {cast_instr} {src_llvm_type} {s} to {dst_llvm_type}"));
373            }
374            Some(())
375        }
376
377        Op::BitCast { src, dtype } => {
378            let s = ctx.get(src);
379            kernel.push(format!("  {dst} = bitcast {} {s} to {}", ldt(&src.dtype()), ldt(dtype)));
380            Some(())
381        }
382
383        Op::Range { axis_id, end, .. } => {
384            let id = axis_id.value();
385            let dtype = ldt(&uop.dtype());
386            let end_val = ctx.get(end).to_string();
387
388            // Track range nesting for correct END footer ordering.
389            ctx.push_range(id);
390
391            // Matches Tinygrad llvmir.py:156-165 exactly:
392            //   entry → loop_entry (preheader) → loop_latch (phi+incr+cmp) → loop_body / loop_exit
393            //   loop_body contains body instructions
394            //   END branches to loop_footer → loop_latch (back edge)
395            kernel.push(format!("  br label %loop_entry_{id}"));
396            kernel.push(format!("loop_entry_{id}:"));
397            kernel.push(format!("  br label %loop_latch_{id}"));
398            kernel.push(format!("loop_latch_{id}:"));
399            kernel.push(format!("  {dst} = phi {dtype} [ 0, %loop_entry_{id} ], [ {dst}phi, %loop_footer_{id} ]"));
400            kernel.push(format!("  {dst}phi = add {dtype} {dst}, 1"));
401            kernel.push(format!("  {dst}cmp = icmp ult {dtype} {dst}, {end_val}"));
402            kernel.push(format!("  br i1 {dst}cmp, label %loop_body_{id}, label %loop_exit_{id}"));
403            kernel.push(format!("loop_body_{id}:"));
404            Some(())
405        }
406
407        Op::End { ranges, .. } => {
408            // After pm_split_ends, each END has exactly one RANGE.
409            // Use the range_stack to emit footer blocks in correct nesting order
410            // (innermost first = LIFO), regardless of the END's ranges field order.
411            let range_count = ranges.iter().filter(|r| matches!(r.op(), Op::Range { .. })).count();
412            for _ in 0..range_count {
413                if let Some(id) = ctx.pop_range() {
414                    // Matches Tinygrad llvmir.py:166-170 exactly:
415                    //   body → loop_footer → loop_latch (back edge)
416                    //   loop_exit: falls through after loop
417                    kernel.push(format!("  br label %loop_footer_{id}"));
418                    kernel.push(format!("loop_footer_{id}:"));
419                    kernel.push(format!("  br label %loop_latch_{id}"));
420                    kernel.push(format!("loop_exit_{id}:"));
421                }
422            }
423
424            let pending = ctx.take_pending_reduces();
425            for (reduce_id, info) in pending {
426                let result_name = format!("%reduce_{reduce_id}.final");
427                kernel.push(format!("  {result_name} = load {}, ptr {}", info.dtype, info.acc_ptr));
428                ctx.register(reduce_id, result_name);
429            }
430            Some(())
431        }
432
433        Op::Reduce { src, ranges, reduce_op } => {
434            let src_val = ctx.get(src);
435            let dtype = ldt(&uop.dtype());
436
437            if ranges.is_empty() {
438                kernel.push(format!("  {dst} = bitcast {dtype} {src_val} to {dtype}"));
439            } else {
440                let acc_ptr = format!("%reduce_{}", uop.id);
441                let acc_load = format!("{acc_ptr}.load");
442                let acc_new = format!("{acc_ptr}.new");
443                let instr = reduce_instr(*reduce_op, &uop.dtype());
444
445                kernel.push(format!("  {acc_load} = load {dtype}, ptr {acc_ptr}"));
446
447                if matches!(reduce_op, ReduceOp::Max | ReduceOp::Min) {
448                    render_reduce_minmax(&acc_new, *reduce_op, &uop.dtype(), &acc_load, src_val, &dtype, kernel);
449                } else {
450                    kernel.push(format!("  {acc_new} = {instr} {dtype} {acc_load}, {src_val}"));
451                }
452
453                kernel.push(format!("  store {dtype} {acc_new}, ptr {acc_ptr}"));
454                ctx.register_reduce_pending(uop.id, acc_ptr.clone(), dtype.clone());
455            }
456            Some(())
457        }
458
459        Op::Gep { vector, indices } => {
460            let vec = ctx.get(vector);
461            let vec_type = ldt(&vector.dtype());
462            let out_type = ldt(&uop.dtype());
463
464            if indices.len() == 1 {
465                kernel.push(format!("  {dst} = extractelement {vec_type} {vec}, i32 {}", indices[0]));
466            } else {
467                render_multi_gep(&dst, vec, &vector.dtype(), indices, &out_type, kernel);
468            }
469            Some(())
470        }
471
472        Op::Vectorize { elements } => {
473            render_vectorize(&dst, elements, ctx, kernel);
474            Some(())
475        }
476
477        Op::Cat { sources } => {
478            render_cat(&dst, sources, ctx, kernel);
479            Some(())
480        }
481
482        Op::PtrCat { .. } => {
483            panic!(
484                "PtrCat must be eliminated before codegen (devectorize should distribute it into scalar loads/stores)"
485            );
486        }
487
488        Op::Contract { src, .. } | Op::Unroll { src, .. } | Op::Detach { src } => {
489            let s = ctx.get(src);
490            ctx.alias(uop.id, s.to_string());
491            None
492        }
493
494        Op::Wmma { a, b, c, metadata } => {
495            // Apple AMX matmul.
496            //
497            // Stack slots `wmma_<id>_amx{0,1,2}` were pre-allocated in the
498            // function entry block (see `llvm/text/mod.rs`); LLVM's mem2reg
499            // pass promotes them to registers across loop iterations, which
500            // is the whole reason for using LLVM here over the C backend.
501            //
502            // Per call: store the 3 src vectors into their allocas, then
503            // `ldz×16 + ldx + ldy + fma + stz×16` via AMX inline asm. The C
504            // operand is a flat 256-elem accumulator; A and B are 16-elem
505            // input vectors. The AMX(op, gpr) macro encodes the row index
506            // and byte offset into the gpr for ldz/stz.
507            let a_val = ctx.get(a);
508            let b_val = ctx.get(b);
509            let c_val = ctx.get(c);
510            let a_dtype = ldt(&a.dtype());
511            let b_dtype = ldt(&b.dtype());
512            let c_dtype = ldt(&c.dtype());
513            let a_align = a.dtype().bytes();
514            let b_align = b.dtype().bytes();
515            let c_align = c.dtype().bytes();
516
517            let id = uop.id;
518            let amx0 = format!("%wmma_{id}_amx0");
519            let amx1 = format!("%wmma_{id}_amx1");
520            let amx2 = format!("%wmma_{id}_amx2");
521            let ptr0 = format!("%wmma_{id}_ptr_amx0");
522            let ptr1 = format!("%wmma_{id}_ptr_amx1");
523            let ptr2 = format!("%wmma_{id}_ptr_amx2");
524
525            // 1. Store A, B, C into their pre-allocated stack slots.
526            kernel.push(format!("  store {a_dtype} {a_val}, ptr {amx0}, align {a_align}"));
527            kernel.push(format!("  store {b_dtype} {b_val}, ptr {amx1}, align {b_align}"));
528            kernel.push(format!("  store {c_dtype} {c_val}, ptr {amx2}, align {c_align}"));
529
530            // 2. AMX_SET(0): enable the AMX coprocessor on this thread.
531            // Without this, every subsequent AMX instruction traps with
532            // SIGILL because the coprocessor is in disabled state.
533            // Encoding: `nop;nop;nop;.word (0x201000 + (17 << 5) + 0)`
534            // = `0x201220`.
535            kernel.push(amx_set_inline_asm(0));
536
537            // 3. ldz × N rows of the C accumulator into Z registers.
538            // AMX `ldz` op = 4. Each row is 64 bytes; row index is encoded in bits 56-59 (i*4<<56),
539            // byte offset is bits 0-9 (i*64). The bytes_per_elem in the encoding is fixed at
540            // 4 because AMX TC is fp32-only.
541            let n_rows = metadata.dims.0; // typically 16 for fp32
542            for i in 0..n_rows {
543                let off = ((i as u64 * 4) << 56) | (i as u64 * 64);
544                let ld_name = format!("%wmma_{id}_ld{i}");
545                kernel.push(format!("  {ld_name} = add i64 {ptr2}, {off}"));
546                kernel.push(amx_inline_asm(4, &ld_name));
547            }
548
549            // 4. ldx (A → X), ldy (B → Y), fma32.
550            kernel.push(amx_inline_asm(0, &ptr1));
551            kernel.push(amx_inline_asm(1, &ptr0));
552            kernel.push(amx_inline_asm_imm(12, 0));
553
554            // 5. stz × N rows of Z back into the C accumulator's stack slot.
555            for i in 0..n_rows {
556                let off = ((i as u64 * 4) << 56) | (i as u64 * 64);
557                let st_name = format!("%wmma_{id}_st{i}");
558                kernel.push(format!("  {st_name} = add i64 {ptr2}, {off}"));
559                kernel.push(amx_inline_asm(5, &st_name));
560            }
561
562            // 6. AMX_SET(1): disable the AMX coprocessor. Pairs with the
563            // enable above.
564            kernel.push(amx_set_inline_asm(1));
565
566            // 7. Load the WMMA result back from the C accumulator stack slot.
567            kernel.push(format!("  {dst} = load {c_dtype}, ptr {amx2}, align {c_align}"));
568            Some(())
569        }
570
571        Op::After { passthrough, .. } => {
572            #[cfg(debug_assertions)]
573            if matches!(passthrough.op(), Op::Range { .. }) {
574                panic!("AFTER passthrough is Range (id={}), this violates Tinygrad semantics", passthrough.id);
575            }
576            let s = ctx.get(passthrough);
577            ctx.alias(uop.id, s.to_string());
578            None
579        }
580
581        Op::Bind { var, value } => {
582            let v = ctx.get(value);
583            ctx.alias(var.id, v.to_string());
584            None
585        }
586
587        Op::If { condition, .. } => {
588            let cond = ctx.get(condition);
589            let if_id = uop.id;
590            kernel.push(format!("  br i1 {cond}, label %if_then_{if_id}, label %if_end_{if_id}"));
591            kernel.push(format!("if_then_{if_id}:"));
592            Some(())
593        }
594
595        Op::EndIf { if_op } => {
596            let if_id = if_op.id;
597            kernel.push(format!("  br label %if_end_{if_id}"));
598            kernel.push(format!("if_end_{if_id}:"));
599            Some(())
600        }
601
602        // CUSTOM / CUSTOMI are intentionally absent: the LLVM text renderer
603        // rejects them at the entry point with a typed error before reaching
604        // here (see `llvm/text/mod.rs`).
605        op if op.is_movement() => {
606            panic!(
607                "movement op {:?} (id={}) reached LLVM codegen — \
608                 should have been eliminated during rangeify. \
609                 This indicates a bug in remove_movement_op or apply_bufferize_transform.",
610                std::mem::discriminant(op),
611                uop.id,
612            );
613        }
614
615        _ => {
616            kernel.push(format!("; UNSUPPORTED: {:?}", uop.op()));
617            None
618        }
619    }
620}
621
622/// Materialize a scalar literal as a value usable in a `dtype`-typed
623/// instruction. For scalar `dtype` returns the literal as-is; for vector
624/// `dtype` emits a splat (insertelement + shufflevector) into `kernel`
625/// and returns the resulting SSA name.
626fn splat_or_literal(scalar_lit: &str, dtype: &DType, kernel: &mut Vec<String>, name_hint: &str) -> String {
627    if dtype.vcount() <= 1 {
628        return scalar_lit.to_string();
629    }
630    let scalar_ty = ldt(&dtype.scalar_dtype());
631    let n = dtype.vcount();
632    let splat_z = format!("{name_hint}.splat0");
633    let splat_v = format!("{name_hint}.splat");
634    kernel.push(format!("  {splat_z} = insertelement <1 x {scalar_ty}> poison, {scalar_ty} {scalar_lit}, i32 0"));
635    kernel.push(format!(
636        "  {splat_v} = shufflevector <1 x {scalar_ty}> {splat_z}, \
637         <1 x {scalar_ty}> poison, <{n} x i32> zeroinitializer"
638    ));
639    splat_v
640}
641
/// Emit an `AMX_SET` instruction that toggles the AMX coprocessor's
/// per-thread state.
///
/// `imm5 = 0` enables AMX and must run before any other AMX instruction.
/// `imm5 = 1` disables it and must run when leaving the AMX block.
///
/// Encoding: three NOPs, then a fixed 32-bit word at
/// `0x201000 + (17 << 5) + imm5`, where `17` is the AMX_SET op slot. This is the
/// same encoding as the `AMX_SET` macro in svod's C backend
/// (`codegen/src/c/amx.rs:39`).
///
/// # Panics
///
/// Panics if `imm5 >= 32`. The immediate occupies only the low 5 bits, so a
/// larger value would carry into the op-slot field and encode a different
/// instruction.
fn amx_set_inline_asm(imm5: u32) -> String {
    assert!(imm5 < 32, "AMX_SET immediate must fit in 5 bits, got {imm5}");
    let opcode = 0x201000u32 + (17 << 5) + imm5;
    format!(
        "  tail call void asm sideeffect \"nop\\0Anop\\0Anop\\0A.word ({opcode})\", \
         \"~{{memory}}\"()"
    )
}
658
/// Emit an Apple AMX inline-asm instruction whose operand is passed in a 64-bit
/// general-purpose register.
///
/// The `.word` directive assembles to `0x201000 + (op << 5) + reg`:
/// - `$0` is the op number, passed through the `i` constraint.
/// - `$1` is the register chosen for `gpr_name` via the `r` constraint. It is
///   printed as a name such as `x12`, so `0$1` becomes the token `0x12` (hex 18).
///   Subtracting `(0x12 >> 4) * 6` converts that back to the decimal register
///   number 12.
///
/// `sideeffect` plus the `~{memory}` clobber keep LLVM from deleting or
/// reordering the state-mutating AMX instruction.
fn amx_inline_asm(op: u32, gpr_name: &str) -> String {
    let encoding = ".word (0x201000+($0<<5)+0$1-((0$1>>4)*6))";
    let constraints = "i,r,~{memory}";
    format!("  tail call void asm sideeffect \"{encoding}\", \"{constraints}\"(i32 {op}, i64 {gpr_name})")
}
672
673/// Emit an AMX inline asm instruction with an immediate operand instead of a
674/// register (used for `fma32` where the operand encoding is `0`).
675fn amx_inline_asm_imm(op: u32, imm: u64) -> String {
676    format!(
677        "  tail call void asm sideeffect \".word (0x201000+($0<<5)+0$1-((0$1>>4)*6))\", \
678         \"i,r,~{{memory}}\"(i32 {op}, i64 {imm})"
679    )
680}
681
682fn binary_instr(op: BinaryOp, dtype: &DType) -> &'static str {
683    assert!(
684        !matches!(dtype.base(), svod_dtype::ScalarDType::Index),
685        "Index dtype reached LLVM codegen binary_instr({op:?}, {dtype:?}) — \
686         pm_lower_index_dtype should have lowered it to i32/i64"
687    );
688    let is_float = dtype.is_float();
689    let is_signed = dtype.is_signed();
690
691    match op {
692        BinaryOp::Add => {
693            if is_float {
694                "fadd nsz arcp contract afn"
695            } else if is_signed {
696                "add nsw"
697            } else {
698                "add"
699            }
700        }
701        BinaryOp::Mul => {
702            if is_float {
703                "fmul nsz arcp contract afn"
704            } else {
705                "mul"
706            }
707        }
708        BinaryOp::Sub => {
709            if is_float {
710                "fsub nsz arcp contract afn"
711            } else {
712                "sub"
713            }
714        }
715        BinaryOp::Fdiv => "fdiv nsz arcp contract afn",
716        BinaryOp::Idiv => {
717            if is_signed {
718                "sdiv"
719            } else {
720                "udiv"
721            }
722        }
723        BinaryOp::Mod => {
724            if is_float {
725                "frem nsz arcp contract afn"
726            } else if is_signed {
727                "srem"
728            } else {
729                "urem"
730            }
731        }
732        BinaryOp::Max => {
733            if is_float {
734                "maxnum"
735            } else if is_signed {
736                "smax"
737            } else {
738                "umax"
739            }
740        }
741        BinaryOp::Lt => {
742            if is_float {
743                "fcmp nsz arcp contract afn ult"
744            } else if is_signed {
745                "icmp slt"
746            } else {
747                "icmp ult"
748            }
749        }
750        BinaryOp::Le => {
751            if is_float {
752                "fcmp nsz arcp contract afn ule"
753            } else if is_signed {
754                "icmp sle"
755            } else {
756                "icmp ule"
757            }
758        }
759        BinaryOp::Gt => {
760            if is_float {
761                "fcmp nsz arcp contract afn ugt"
762            } else if is_signed {
763                "icmp sgt"
764            } else {
765                "icmp ugt"
766            }
767        }
768        BinaryOp::Ge => {
769            if is_float {
770                "fcmp nsz arcp contract afn uge"
771            } else if is_signed {
772                "icmp sge"
773            } else {
774                "icmp uge"
775            }
776        }
777        BinaryOp::Eq => {
778            if is_float {
779                "fcmp nsz arcp contract afn oeq"
780            } else {
781                "icmp eq"
782            }
783        }
784        BinaryOp::Ne => {
785            if is_float {
786                "fcmp nsz arcp contract afn une"
787            } else {
788                "icmp ne"
789            }
790        }
791        BinaryOp::And => "and",
792        BinaryOp::Or => "or",
793        BinaryOp::Xor => "xor",
794        BinaryOp::Shl => "shl",
795        BinaryOp::Shr => {
796            if is_signed {
797                "ashr"
798            } else {
799                "lshr"
800            }
801        }
802        BinaryOp::Pow => "pow",
803        BinaryOp::Threefry => "xor",
804    }
805}
806
807fn unary_instr(op: UnaryOp, dtype: &DType) -> Option<&'static str> {
808    let is_float = dtype.is_float();
809
810    match op {
811        UnaryOp::Neg => Some(if is_float { "fneg" } else { "sub" }),
812        UnaryOp::Not => Some("xor"),
813        UnaryOp::Sqrt => Some("sqrt"),
814        UnaryOp::Rsqrt => None,
815        UnaryOp::Exp => Some("exp"),
816        UnaryOp::Exp2 => Some("exp2"),
817        UnaryOp::Log => Some("log"),
818        UnaryOp::Log2 => Some("log2"),
819        UnaryOp::Sin => Some("sin"),
820        UnaryOp::Cos => Some("cos"),
821        UnaryOp::Abs => Some(if is_float { "fabs" } else { "abs" }),
822        UnaryOp::Floor => Some("floor"),
823        UnaryOp::Ceil => Some("ceil"),
824        UnaryOp::Trunc => Some("trunc"),
825        UnaryOp::Round => Some("rint"),
826        UnaryOp::Reciprocal => None,
827        UnaryOp::Tan => None,
828        UnaryOp::Sign => None,
829        UnaryOp::Erf => None,
830        UnaryOp::Square => None,
831    }
832}
833
834fn reduce_instr(op: ReduceOp, dtype: &DType) -> &'static str {
835    let is_float = dtype.is_float();
836    let is_signed = dtype.is_signed();
837
838    match op {
839        ReduceOp::Add => {
840            if is_float {
841                "fadd nsz arcp contract afn"
842            } else {
843                "add"
844            }
845        }
846        ReduceOp::Mul => {
847            if is_float {
848                "fmul nsz arcp contract afn"
849            } else {
850                "mul"
851            }
852        }
853        ReduceOp::Max => {
854            if is_float {
855                "maxnum"
856            } else if is_signed {
857                "smax"
858            } else {
859                "umax"
860            }
861        }
862        ReduceOp::Min => {
863            if is_float {
864                "minnum"
865            } else if is_signed {
866                "smin"
867            } else {
868                "umin"
869            }
870        }
871    }
872}
873
874fn mangle_type(llvm_type: &str) -> String {
875    match llvm_type {
876        "float" => "f32".to_string(),
877        "double" => "f64".to_string(),
878        "half" => "f16".to_string(),
879        "i8" => "i8".to_string(),
880        "i16" => "i16".to_string(),
881        "i32" => "i32".to_string(),
882        "i64" => "i64".to_string(),
883        _ if llvm_type.starts_with('<') && llvm_type.ends_with('>') => {
884            let inner = &llvm_type[1..llvm_type.len() - 1];
885            let parts: Vec<&str> = inner.split(" x ").collect();
886            if parts.len() == 2 {
887                let count = parts[0].trim();
888                let base = mangle_type(parts[1].trim());
889                format!("v{count}{base}")
890            } else {
891                llvm_type.to_string()
892            }
893        }
894        _ => llvm_type.to_string(),
895    }
896}
897
898fn render_intrinsic(dst: &str, name: &str, args: &[(&str, &str)], ret_type: &str, kernel: &mut Vec<String>) {
899    let args_str: String = args.iter().map(|(ty, val)| format!("{ty} {val}")).collect::<Vec<_>>().join(", ");
900    let mangled = mangle_type(ret_type);
901    kernel.push(format!("  {dst} = call {ret_type} @llvm.{name}.{mangled}({args_str})"));
902}
903
904fn render_binary_max(dst: &str, lhs: &Arc<UOp>, l: &str, r: &str, ltype: &str, kernel: &mut Vec<String>) {
905    if lhs.dtype().is_float() {
906        render_intrinsic(dst, "maxnum", &[(ltype, l), (ltype, r)], ltype, kernel);
907    } else {
908        let is_signed = lhs.dtype().is_signed();
909        let cmp = if is_signed { "sgt" } else { "ugt" };
910        let cmp_dst = format!("{dst}.cmp");
911        kernel.push(format!("  {cmp_dst} = icmp {cmp} {ltype} {l}, {r}"));
912        kernel.push(format!("  {dst} = select i1 {cmp_dst}, {ltype} {l}, {ltype} {r}"));
913    }
914}
915
916fn render_binary_pow(dst: &str, lhs: &Arc<UOp>, l: &str, r: &str, ltype: &str, kernel: &mut Vec<String>) {
917    if lhs.dtype().is_float() {
918        render_intrinsic(dst, "pow", &[(ltype, l), (ltype, r)], ltype, kernel);
919    } else {
920        let l_float = format!("{dst}.lf");
921        let r_float = format!("{dst}.rf");
922        let pow_float = format!("{dst}.pf");
923        kernel.push(format!("  {l_float} = sitofp {ltype} {l} to double"));
924        kernel.push(format!("  {r_float} = sitofp {ltype} {r} to double"));
925        render_intrinsic(&pow_float, "pow", &[("double", &l_float), ("double", &r_float)], "double", kernel);
926        kernel.push(format!("  {dst} = fptosi double {pow_float} to {ltype}"));
927    }
928}
929
930fn render_reduce_minmax(
931    dst: &str,
932    op: ReduceOp,
933    dtype: &DType,
934    acc: &str,
935    val: &str,
936    ltype: &str,
937    kernel: &mut Vec<String>,
938) {
939    if dtype.is_float() {
940        let intrinsic = match op {
941            ReduceOp::Max => "maxnum",
942            ReduceOp::Min => "minnum",
943            _ => unreachable!(),
944        };
945        render_intrinsic(dst, intrinsic, &[(ltype, acc), (ltype, val)], ltype, kernel);
946    } else {
947        let is_signed = dtype.is_signed();
948        let cmp = match op {
949            ReduceOp::Max => {
950                if is_signed {
951                    "sgt"
952                } else {
953                    "ugt"
954                }
955            }
956            ReduceOp::Min => {
957                if is_signed {
958                    "slt"
959                } else {
960                    "ult"
961                }
962            }
963            _ => unreachable!(),
964        };
965        let cmp_dst = format!("{dst}.cmp");
966        kernel.push(format!("  {cmp_dst} = icmp {cmp} {ltype} {acc}, {val}"));
967        kernel.push(format!("  {dst} = select i1 {cmp_dst}, {ltype} {acc}, {ltype} {val}"));
968    }
969}
970
971fn render_multi_gep(
972    dst: &str,
973    vec: &str,
974    vec_dtype: &DType,
975    indices: &[usize],
976    out_type: &str,
977    kernel: &mut Vec<String>,
978) {
979    let vec_type = ldt(vec_dtype);
980
981    let elem_dtype = match vec_dtype {
982        DType::Ptr { base, addrspace, size, .. } => {
983            DType::Ptr { base: base.clone(), addrspace: *addrspace, size: *size, vcount: 1 }
984        }
985        DType::Vector { scalar, .. } => DType::Scalar(*scalar),
986        _ => DType::Scalar(vec_dtype.base()),
987    };
988    let elem_type = ldt(&elem_dtype);
989
990    for (i, &idx) in indices.iter().enumerate() {
991        let elem = format!("{dst}.e{i}");
992        kernel.push(format!("  {elem} = extractelement {vec_type} {vec}, i32 {idx}"));
993    }
994
995    if indices.len() == 1 {
996        kernel.push(format!("  {dst} = bitcast {elem_type} {dst}.e0 to {out_type}"));
997    } else {
998        let count = indices.len();
999        kernel.push(format!("  {dst}.undef = undef <{count} x {elem_type}>"));
1000        let mut prev = format!("{dst}.undef");
1001        for i in 0..count {
1002            let next = if i == count - 1 { dst.to_string() } else { format!("{dst}.v{i}") };
1003            kernel.push(format!(
1004                "  {next} = insertelement <{count} x {elem_type}> {prev}, {elem_type} {dst}.e{i}, i32 {i}"
1005            ));
1006            prev = next;
1007        }
1008    }
1009}
1010
1011fn render_vectorize(dst: &str, elements: &[Arc<UOp>], ctx: &RenderContext, kernel: &mut Vec<String>) {
1012    if elements.is_empty() {
1013        return;
1014    }
1015
1016    let scalar_type = ldt(&elements[0].dtype());
1017    let count = elements.len();
1018    let vec_type = format!("<{count} x {scalar_type}>");
1019
1020    let mut prev = "undef".to_string();
1021    for (i, elem) in elements.iter().enumerate() {
1022        let val = ctx.get(elem);
1023        let next = if i == count - 1 { dst.to_string() } else { format!("{dst}.v{i}") };
1024        kernel.push(format!("  {next} = insertelement {vec_type} {prev}, {scalar_type} {val}, i32 {i}"));
1025        prev = next;
1026    }
1027}
1028
1029fn render_cat(dst: &str, sources: &[Arc<UOp>], ctx: &RenderContext, kernel: &mut Vec<String>) {
1030    if sources.is_empty() {
1031        return;
1032    }
1033
1034    let total_count: usize = sources.iter().map(|s| s.dtype().vcount()).sum();
1035    let scalar_type = ldt(&sources[0].dtype().scalar_dtype());
1036    let out_type = format!("<{total_count} x {scalar_type}>");
1037
1038    let mut out_idx = 0;
1039    let mut prev = "undef".to_string();
1040
1041    for src in sources.iter() {
1042        let src_val = ctx.get(src);
1043        let src_count = src.dtype().vcount();
1044
1045        if src_count == 1 {
1046            let next = if out_idx == total_count - 1 { dst.to_string() } else { format!("{dst}.c{out_idx}") };
1047            kernel.push(format!("  {next} = insertelement {out_type} {prev}, {scalar_type} {src_val}, i32 {out_idx}"));
1048            prev = next;
1049            out_idx += 1;
1050        } else {
1051            let src_type = ldt(&src.dtype());
1052            for i in 0..src_count {
1053                let elem = format!("{dst}.e{out_idx}");
1054                kernel.push(format!("  {elem} = extractelement {src_type} {src_val}, i32 {i}"));
1055
1056                let next = if out_idx == total_count - 1 { dst.to_string() } else { format!("{dst}.c{out_idx}") };
1057                kernel.push(format!("  {next} = insertelement {out_type} {prev}, {scalar_type} {elem}, i32 {out_idx}"));
1058                prev = next;
1059                out_idx += 1;
1060            }
1061        }
1062    }
1063}
1064
1065/// Get identity element for reduce operation.
1066pub fn reduce_identity(op: ReduceOp, dtype: &DType) -> String {
1067    let is_vector = matches!(dtype, DType::Vector { .. });
1068
1069    match op {
1070        ReduceOp::Add => {
1071            if is_vector {
1072                "zeroinitializer".to_string()
1073            } else if dtype.is_float() {
1074                "0.0".to_string()
1075            } else {
1076                "0".to_string()
1077            }
1078        }
1079        ReduceOp::Mul => {
1080            if is_vector {
1081                "zeroinitializer".to_string()
1082            } else if dtype.is_float() {
1083                "1.0".to_string()
1084            } else {
1085                "1".to_string()
1086            }
1087        }
1088        ReduceOp::Max => {
1089            if is_vector {
1090                "zeroinitializer".to_string()
1091            } else if dtype.is_float() {
1092                "-0x7FF0000000000000".to_string()
1093            } else if dtype.is_signed() {
1094                i64::MIN.to_string()
1095            } else {
1096                "0".to_string()
1097            }
1098        }
1099        ReduceOp::Min => {
1100            if is_vector {
1101                "zeroinitializer".to_string() // TODO: proper +inf splat
1102            } else if dtype.is_float() {
1103                "0x7FF0000000000000".to_string() // +inf
1104            } else if dtype.is_signed() {
1105                i64::MAX.to_string()
1106            } else {
1107                u64::MAX.to_string()
1108            }
1109        }
1110    }
1111}