1use std::sync::Arc;
7
8use svod_dtype::DType;
9use svod_ir::{BinaryOp, Op, ReduceOp, TernaryOp, UnaryOp, prelude::*};
10
11use crate::llvm::common::{RenderContext, lcast, ldt};
12
13fn maybe_extract_scalar_ptr(
20 dst: &str,
21 idx: &str,
22 idx_type: &str,
23 dtype: &DType,
24 kernel: &mut Vec<String>,
25) -> (String, String) {
26 if matches!(dtype, DType::Ptr { vcount, .. } if *vcount > 1) {
27 let extract = format!("{dst}.ptr");
28 kernel.push(format!(" {extract} = extractelement {idx_type} {idx}, i32 0"));
29 (extract, "ptr".to_string())
30 } else {
31 (idx.to_string(), idx_type.to_string())
32 }
33}
34
35pub fn render_uop(uop: &Arc<UOp>, ctx: &mut RenderContext, kernel: &mut Vec<String>) -> Option<()> {
39 let dst = ctx.name(uop);
40
41 match uop.op() {
42 Op::Const(_)
43 | Op::VConst { .. }
44 | Op::Param { device: None, .. }
45 | Op::DefineVar { .. }
46 | Op::Noop
47 | Op::Sink { .. }
48 | Op::Group { .. }
49 | Op::Buffer { .. }
50 | Op::Unique(_)
51 | Op::Device(_)
52 | Op::Call { .. }
53 | Op::Barrier { .. } => None,
54
55 Op::DefineLocal(_) | Op::DefineReg { .. } => {
56 let (base_dtype, alloc_size) = match uop.dtype() {
60 DType::Ptr { base, size, .. } => (base.as_ref().clone(), size.unwrap_or(1)),
61 other => (other, 1),
62 };
63 let base = ldt(&base_dtype);
64 let align = if matches!(uop.op(), Op::DefineLocal(_)) { ", align 16" } else { "" };
66 kernel.push(format!(" {dst} = alloca [{alloc_size} x {base}]{align}"));
67 Some(())
68 }
69
70 Op::Index { buffer, indices, .. } => {
71 let buf = ctx.get(buffer);
72 let buf_type = ldt(&buffer.dtype());
73
74 if indices.is_empty() {
75 kernel.push(format!(" {dst} = bitcast {buf_type} {buf} to {}", ldt(&uop.dtype())));
76 } else {
77 let (final_idx, final_idx_type) = if indices.len() == 1 {
78 (ctx.get(&indices[0]).to_string(), ldt(&indices[0].dtype()))
79 } else {
80 ctx.set_invalid_graph(format!(
81 "LLVM renderer requires linearized INDEX (single-axis), found {} indices on uop {}",
82 indices.len(),
83 uop.id
84 ));
85 return None;
86 };
87
88 let elem_type = match uop.dtype() {
89 svod_dtype::DType::Ptr { ref base, .. } => ldt(base),
90 other => ldt(&other),
91 };
92
93 kernel.push(format!(
97 " {dst} = getelementptr inbounds {elem_type}, {buf_type} {buf}, {final_idx_type} {final_idx}"
98 ));
99 }
100 Some(())
101 }
102
103 Op::PointerIndex { ptr, offset } => {
104 let ptr_val = ctx.get(ptr);
105 let off_val = ctx.get(offset);
106 let elem_type = ldt(&uop.dtype());
107 let ptr_type = ldt(&ptr.dtype());
108 let off_type = ldt(&offset.dtype());
109
110 kernel.push(format!(
111 " {dst} = getelementptr inbounds {elem_type}, {ptr_type} {ptr_val}, {off_type} {off_val}"
112 ));
113 Some(())
114 }
115
116 Op::Load { index, alt, .. } => {
117 let idx = ctx.get(index);
118 let dtype = ldt(&uop.dtype());
119 let idx_type = ldt(&index.dtype());
120
121 let (idx, idx_type) = maybe_extract_scalar_ptr(&dst, idx, &idx_type, &index.dtype(), kernel);
122
123 let actual_index = match index.op() {
130 Op::Cast { src, .. } => src,
131 _ => index,
132 };
133 let gate_info = if let Op::Index { gate: Some(gate_uop), .. } = actual_index.op() {
134 let Some(alt_uop) = alt.as_ref() else {
135 ctx.set_invalid_graph(format!(
136 "gated LOAD on uop {} has no alt value; line_rewrite_cleanups must lift gated LOADs",
137 uop.id
138 ));
139 return None;
140 };
141 Some((ctx.get(gate_uop).to_string(), ctx.get(alt_uop).to_string()))
142 } else {
143 None
144 };
145
146 if let Some((gate, alt_val)) = gate_info {
147 let label_base = &dst[1..]; let entry_label = format!("{label_base}_entry");
149 let load_label = format!("{label_base}_load");
150 let exit_label = format!("{label_base}_exit");
151 let load_val = format!("{dst}_yes");
152
153 kernel.push(format!(" br label %{entry_label}"));
154 kernel.push(format!("{entry_label}:"));
155 kernel.push(format!(" br i1 {gate}, label %{load_label}, label %{exit_label}"));
156 kernel.push(format!("{load_label}:"));
157 kernel.push(format!(" {load_val} = load {dtype}, {idx_type} {idx}"));
158 kernel.push(format!(" br label %{exit_label}"));
159 kernel.push(format!("{exit_label}:"));
160 kernel.push(format!(" {dst} = phi {dtype} [{load_val}, %{load_label}], [{alt_val}, %{entry_label}]"));
161 } else {
162 kernel.push(format!(" {dst} = load {dtype}, {idx_type} {idx}"));
163 }
164 Some(())
165 }
166
167 Op::Store { index, value, .. } => {
168 let idx = ctx.get(index);
169 let val = ctx.get(value);
170 let val_type = ldt(&value.dtype());
171 let idx_type = ldt(&index.dtype());
172
173 let (idx, idx_type) = maybe_extract_scalar_ptr(&dst, idx, &idx_type, &index.dtype(), kernel);
174
175 kernel.push(format!(" store {val_type} {val}, {idx_type} {idx}"));
176 Some(())
177 }
178
179 Op::Binary(op, lhs, rhs) => {
180 let l = ctx.get(lhs);
181 let r = ctx.get(rhs);
182 let ltype = ldt(&lhs.dtype());
183 let rtype = ldt(&rhs.dtype());
184
185 if ltype != rtype {
187 tracing::error!(
188 uop_id = uop.id,
189 uop_dtype = ?uop.dtype(),
190 op = ?op,
191 lhs_id = lhs.id,
192 rhs_id = rhs.id,
193 lhs_dtype = ?lhs.dtype(),
194 rhs_dtype = ?rhs.dtype(),
195 lhs_op = ?lhs.op().as_ref(),
196 rhs_op = ?rhs.op().as_ref(),
197 "Binary op type mismatch - lhs and rhs have different dtypes"
198 );
199 }
200
201 if matches!(op, BinaryOp::Max) {
202 render_binary_max(&dst, lhs, l, r, <ype, kernel);
203 } else if matches!(op, BinaryOp::Pow) {
204 render_binary_pow(&dst, lhs, l, r, <ype, kernel);
205 } else {
206 let instr = binary_instr(*op, &lhs.dtype());
207 kernel.push(format!(" {dst} = {instr} {ltype} {l}, {r}"));
208 }
209 Some(())
210 }
211
212 Op::Unary(op, src) => {
213 let s = ctx.get(src);
214 let stype = ldt(&src.dtype());
215
216 match op {
217 UnaryOp::Neg => {
218 if src.dtype().is_float() {
219 kernel.push(format!(" {dst} = fneg {stype} {s}"));
220 } else {
221 kernel.push(format!(" {dst} = sub {stype} 0, {s}"));
222 }
223 }
224 UnaryOp::Not => {
225 let all_ones = if src.dtype().is_bool() { "1".to_string() } else { "-1".to_string() };
226 kernel.push(format!(" {dst} = xor {stype} {s}, {all_ones}"));
227 }
228 UnaryOp::Floor | UnaryOp::Ceil | UnaryOp::Trunc | UnaryOp::Round if !src.dtype().is_float() => {
229 kernel.push(format!(" {dst} = bitcast {stype} {s} to {stype}"));
232 }
233 UnaryOp::Sqrt
234 | UnaryOp::Exp
235 | UnaryOp::Exp2
236 | UnaryOp::Log
237 | UnaryOp::Log2
238 | UnaryOp::Sin
239 | UnaryOp::Cos
240 | UnaryOp::Floor
241 | UnaryOp::Ceil
242 | UnaryOp::Trunc
243 | UnaryOp::Round => {
244 let intrinsic = unary_instr(*op, &src.dtype()).unwrap();
245 render_intrinsic(&dst, intrinsic, &[(&stype, s)], &stype, kernel);
246 }
247 UnaryOp::Abs => {
248 if src.dtype().is_float() {
249 render_intrinsic(&dst, "fabs", &[(&stype, s)], &stype, kernel);
250 } else {
251 render_intrinsic(&dst, "abs", &[(&stype, s), ("i1", "1")], &stype, kernel);
252 }
253 }
254 UnaryOp::Rsqrt => {
255 let sqrt_dst = format!("{dst}.sqrt");
256 render_intrinsic(&sqrt_dst, "sqrt", &[(&stype, s)], &stype, kernel);
257 let one = splat_or_literal("1.0", &src.dtype(), kernel, &dst);
258 kernel.push(format!(" {dst} = fdiv nsz arcp contract afn {stype} {one}, {sqrt_dst}"));
259 }
260 UnaryOp::Reciprocal => {
261 let one = splat_or_literal("1.0", &src.dtype(), kernel, &dst);
262 kernel.push(format!(" {dst} = fdiv nsz arcp contract afn {stype} {one}, {s}"));
263 }
264 UnaryOp::Tan => {
265 let sin_dst = format!("{dst}.sin");
266 let cos_dst = format!("{dst}.cos");
267 render_intrinsic(&sin_dst, "sin", &[(&stype, s)], &stype, kernel);
268 render_intrinsic(&cos_dst, "cos", &[(&stype, s)], &stype, kernel);
269 kernel.push(format!(" {dst} = fdiv nsz arcp contract afn {stype} {sin_dst}, {cos_dst}"));
270 }
271 UnaryOp::Sign => {
272 if src.dtype().is_float() {
273 let gt_zero = format!("{dst}.gt");
274 let lt_zero = format!("{dst}.lt");
275 let gt_ext = format!("{dst}.gt_ext");
276 let lt_ext = format!("{dst}.lt_ext");
277 let zero = splat_or_literal("0.0", &src.dtype(), kernel, &dst);
278 kernel.push(format!(" {gt_zero} = fcmp nsz arcp contract afn ogt {stype} {s}, {zero}"));
279 kernel.push(format!(" {lt_zero} = fcmp nsz arcp contract afn olt {stype} {s}, {zero}"));
280 kernel.push(format!(" {gt_ext} = uitofp i1 {gt_zero} to {stype}"));
281 kernel.push(format!(" {lt_ext} = uitofp i1 {lt_zero} to {stype}"));
282 kernel.push(format!(" {dst} = fsub nsz arcp contract afn {stype} {gt_ext}, {lt_ext}"));
283 } else if src.dtype().is_signed() {
284 let gt_zero = format!("{dst}.gt");
285 let lt_zero = format!("{dst}.lt");
286 let gt_ext = format!("{dst}.gt_ext");
287 let lt_ext = format!("{dst}.lt_ext");
288 let zero = splat_or_literal("0", &src.dtype(), kernel, &dst);
289 kernel.push(format!(" {gt_zero} = icmp sgt {stype} {s}, {zero}"));
290 kernel.push(format!(" {lt_zero} = icmp slt {stype} {s}, {zero}"));
291 kernel.push(format!(" {gt_ext} = zext i1 {gt_zero} to {stype}"));
292 kernel.push(format!(" {lt_ext} = zext i1 {lt_zero} to {stype}"));
293 kernel.push(format!(" {dst} = sub {stype} {gt_ext}, {lt_ext}"));
294 } else {
295 let ne_zero = format!("{dst}.ne");
297 let zero = splat_or_literal("0", &src.dtype(), kernel, &dst);
298 kernel.push(format!(" {ne_zero} = icmp ne {stype} {s}, {zero}"));
299 kernel.push(format!(" {dst} = zext i1 {ne_zero} to {stype}"));
300 }
301 }
302 UnaryOp::Erf => {
303 render_intrinsic(&dst, "erf", &[(&stype, s)], &stype, kernel);
304 }
305 UnaryOp::Square => {
306 if src.dtype().is_float() {
307 kernel.push(format!(" {dst} = fmul nsz arcp contract afn {stype} {s}, {s}"));
308 } else {
309 kernel.push(format!(" {dst} = mul {stype} {s}, {s}"));
310 }
311 }
312 }
313 Some(())
314 }
315
316 Op::Ternary(TernaryOp::Where, cond, t, f) => {
317 let c = ctx.get(cond);
318 let tv = ctx.get(t);
319 let fv = ctx.get(f);
320 kernel.push(format!(
321 " {dst} = select {} {c}, {} {tv}, {} {fv}",
322 ldt(&cond.dtype()),
323 ldt(&t.dtype()),
324 ldt(&f.dtype())
325 ));
326 Some(())
327 }
328
329 Op::Ternary(TernaryOp::MulAcc, a, b, c) => {
330 let av = ctx.get(a);
331 let bv = ctx.get(b);
332 let cv = ctx.get(c);
333 let dtype = ldt(&a.dtype());
334
335 if a.dtype().is_float() {
336 render_intrinsic(&dst, "fmuladd", &[(&dtype, av), (&dtype, bv), (&dtype, cv)], &dtype, kernel);
337 } else {
338 let mul_dst = format!("{dst}.mul");
339 kernel.push(format!(" {mul_dst} = mul {dtype} {av}, {bv}"));
340 kernel.push(format!(" {dst} = add {dtype} {mul_dst}, {cv}"));
341 }
342 Some(())
343 }
344
345 Op::Cast { src, dtype } => {
346 let s = ctx.get(src);
347
348 let is_index_src = matches!(src.op(), Op::Index { .. });
351 let src_llvm_type = if is_index_src { "ptr".to_string() } else { ldt(&src.dtype()) };
352 let dst_llvm_type = ldt(dtype);
353
354 if is_index_src && matches!(dtype, DType::Ptr { .. }) {
358 kernel.push(format!(" {dst} = bitcast ptr {s} to ptr"));
360 return Some(());
361 }
362
363 if dtype.is_bool() && !src.dtype().is_bool() {
364 let cmp = if src.dtype().is_float() { "fcmp nsz arcp contract afn une" } else { "icmp ne" };
367 kernel.push(format!(" {dst} = {cmp} {src_llvm_type} {s}, zeroinitializer"));
368 } else if src_llvm_type == dst_llvm_type {
369 kernel.push(format!(" {dst} = bitcast {src_llvm_type} {s} to {dst_llvm_type}"));
370 } else {
371 let cast_instr = lcast(&src.dtype(), dtype);
372 kernel.push(format!(" {dst} = {cast_instr} {src_llvm_type} {s} to {dst_llvm_type}"));
373 }
374 Some(())
375 }
376
377 Op::BitCast { src, dtype } => {
378 let s = ctx.get(src);
379 kernel.push(format!(" {dst} = bitcast {} {s} to {}", ldt(&src.dtype()), ldt(dtype)));
380 Some(())
381 }
382
383 Op::Range { axis_id, end, .. } => {
384 let id = axis_id.value();
385 let dtype = ldt(&uop.dtype());
386 let end_val = ctx.get(end).to_string();
387
388 ctx.push_range(id);
390
391 kernel.push(format!(" br label %loop_entry_{id}"));
396 kernel.push(format!("loop_entry_{id}:"));
397 kernel.push(format!(" br label %loop_latch_{id}"));
398 kernel.push(format!("loop_latch_{id}:"));
399 kernel.push(format!(" {dst} = phi {dtype} [ 0, %loop_entry_{id} ], [ {dst}phi, %loop_footer_{id} ]"));
400 kernel.push(format!(" {dst}phi = add {dtype} {dst}, 1"));
401 kernel.push(format!(" {dst}cmp = icmp ult {dtype} {dst}, {end_val}"));
402 kernel.push(format!(" br i1 {dst}cmp, label %loop_body_{id}, label %loop_exit_{id}"));
403 kernel.push(format!("loop_body_{id}:"));
404 Some(())
405 }
406
407 Op::End { ranges, .. } => {
408 let range_count = ranges.iter().filter(|r| matches!(r.op(), Op::Range { .. })).count();
412 for _ in 0..range_count {
413 if let Some(id) = ctx.pop_range() {
414 kernel.push(format!(" br label %loop_footer_{id}"));
418 kernel.push(format!("loop_footer_{id}:"));
419 kernel.push(format!(" br label %loop_latch_{id}"));
420 kernel.push(format!("loop_exit_{id}:"));
421 }
422 }
423
424 let pending = ctx.take_pending_reduces();
425 for (reduce_id, info) in pending {
426 let result_name = format!("%reduce_{reduce_id}.final");
427 kernel.push(format!(" {result_name} = load {}, ptr {}", info.dtype, info.acc_ptr));
428 ctx.register(reduce_id, result_name);
429 }
430 Some(())
431 }
432
433 Op::Reduce { src, ranges, reduce_op } => {
434 let src_val = ctx.get(src);
435 let dtype = ldt(&uop.dtype());
436
437 if ranges.is_empty() {
438 kernel.push(format!(" {dst} = bitcast {dtype} {src_val} to {dtype}"));
439 } else {
440 let acc_ptr = format!("%reduce_{}", uop.id);
441 let acc_load = format!("{acc_ptr}.load");
442 let acc_new = format!("{acc_ptr}.new");
443 let instr = reduce_instr(*reduce_op, &uop.dtype());
444
445 kernel.push(format!(" {acc_load} = load {dtype}, ptr {acc_ptr}"));
446
447 if matches!(reduce_op, ReduceOp::Max | ReduceOp::Min) {
448 render_reduce_minmax(&acc_new, *reduce_op, &uop.dtype(), &acc_load, src_val, &dtype, kernel);
449 } else {
450 kernel.push(format!(" {acc_new} = {instr} {dtype} {acc_load}, {src_val}"));
451 }
452
453 kernel.push(format!(" store {dtype} {acc_new}, ptr {acc_ptr}"));
454 ctx.register_reduce_pending(uop.id, acc_ptr.clone(), dtype.clone());
455 }
456 Some(())
457 }
458
459 Op::Gep { vector, indices } => {
460 let vec = ctx.get(vector);
461 let vec_type = ldt(&vector.dtype());
462 let out_type = ldt(&uop.dtype());
463
464 if indices.len() == 1 {
465 kernel.push(format!(" {dst} = extractelement {vec_type} {vec}, i32 {}", indices[0]));
466 } else {
467 render_multi_gep(&dst, vec, &vector.dtype(), indices, &out_type, kernel);
468 }
469 Some(())
470 }
471
472 Op::Vectorize { elements } => {
473 render_vectorize(&dst, elements, ctx, kernel);
474 Some(())
475 }
476
477 Op::Cat { sources } => {
478 render_cat(&dst, sources, ctx, kernel);
479 Some(())
480 }
481
482 Op::PtrCat { .. } => {
483 panic!(
484 "PtrCat must be eliminated before codegen (devectorize should distribute it into scalar loads/stores)"
485 );
486 }
487
488 Op::Contract { src, .. } | Op::Unroll { src, .. } | Op::Detach { src } => {
489 let s = ctx.get(src);
490 ctx.alias(uop.id, s.to_string());
491 None
492 }
493
494 Op::Wmma { a, b, c, metadata } => {
495 let a_val = ctx.get(a);
508 let b_val = ctx.get(b);
509 let c_val = ctx.get(c);
510 let a_dtype = ldt(&a.dtype());
511 let b_dtype = ldt(&b.dtype());
512 let c_dtype = ldt(&c.dtype());
513 let a_align = a.dtype().bytes();
514 let b_align = b.dtype().bytes();
515 let c_align = c.dtype().bytes();
516
517 let id = uop.id;
518 let amx0 = format!("%wmma_{id}_amx0");
519 let amx1 = format!("%wmma_{id}_amx1");
520 let amx2 = format!("%wmma_{id}_amx2");
521 let ptr0 = format!("%wmma_{id}_ptr_amx0");
522 let ptr1 = format!("%wmma_{id}_ptr_amx1");
523 let ptr2 = format!("%wmma_{id}_ptr_amx2");
524
525 kernel.push(format!(" store {a_dtype} {a_val}, ptr {amx0}, align {a_align}"));
527 kernel.push(format!(" store {b_dtype} {b_val}, ptr {amx1}, align {b_align}"));
528 kernel.push(format!(" store {c_dtype} {c_val}, ptr {amx2}, align {c_align}"));
529
530 kernel.push(amx_set_inline_asm(0));
536
537 let n_rows = metadata.dims.0; for i in 0..n_rows {
543 let off = ((i as u64 * 4) << 56) | (i as u64 * 64);
544 let ld_name = format!("%wmma_{id}_ld{i}");
545 kernel.push(format!(" {ld_name} = add i64 {ptr2}, {off}"));
546 kernel.push(amx_inline_asm(4, &ld_name));
547 }
548
549 kernel.push(amx_inline_asm(0, &ptr1));
551 kernel.push(amx_inline_asm(1, &ptr0));
552 kernel.push(amx_inline_asm_imm(12, 0));
553
554 for i in 0..n_rows {
556 let off = ((i as u64 * 4) << 56) | (i as u64 * 64);
557 let st_name = format!("%wmma_{id}_st{i}");
558 kernel.push(format!(" {st_name} = add i64 {ptr2}, {off}"));
559 kernel.push(amx_inline_asm(5, &st_name));
560 }
561
562 kernel.push(amx_set_inline_asm(1));
565
566 kernel.push(format!(" {dst} = load {c_dtype}, ptr {amx2}, align {c_align}"));
568 Some(())
569 }
570
571 Op::After { passthrough, .. } => {
572 #[cfg(debug_assertions)]
573 if matches!(passthrough.op(), Op::Range { .. }) {
574 panic!("AFTER passthrough is Range (id={}), this violates Tinygrad semantics", passthrough.id);
575 }
576 let s = ctx.get(passthrough);
577 ctx.alias(uop.id, s.to_string());
578 None
579 }
580
581 Op::Bind { var, value } => {
582 let v = ctx.get(value);
583 ctx.alias(var.id, v.to_string());
584 None
585 }
586
587 Op::If { condition, .. } => {
588 let cond = ctx.get(condition);
589 let if_id = uop.id;
590 kernel.push(format!(" br i1 {cond}, label %if_then_{if_id}, label %if_end_{if_id}"));
591 kernel.push(format!("if_then_{if_id}:"));
592 Some(())
593 }
594
595 Op::EndIf { if_op } => {
596 let if_id = if_op.id;
597 kernel.push(format!(" br label %if_end_{if_id}"));
598 kernel.push(format!("if_end_{if_id}:"));
599 Some(())
600 }
601
602 op if op.is_movement() => {
606 panic!(
607 "movement op {:?} (id={}) reached LLVM codegen — \
608 should have been eliminated during rangeify. \
609 This indicates a bug in remove_movement_op or apply_bufferize_transform.",
610 std::mem::discriminant(op),
611 uop.id,
612 );
613 }
614
615 _ => {
616 kernel.push(format!("; UNSUPPORTED: {:?}", uop.op()));
617 None
618 }
619 }
620}
621
622fn splat_or_literal(scalar_lit: &str, dtype: &DType, kernel: &mut Vec<String>, name_hint: &str) -> String {
627 if dtype.vcount() <= 1 {
628 return scalar_lit.to_string();
629 }
630 let scalar_ty = ldt(&dtype.scalar_dtype());
631 let n = dtype.vcount();
632 let splat_z = format!("{name_hint}.splat0");
633 let splat_v = format!("{name_hint}.splat");
634 kernel.push(format!(" {splat_z} = insertelement <1 x {scalar_ty}> poison, {scalar_ty} {scalar_lit}, i32 0"));
635 kernel.push(format!(
636 " {splat_v} = shufflevector <1 x {scalar_ty}> {splat_z}, \
637 <1 x {scalar_ty}> poison, <{n} x i32> zeroinitializer"
638 ));
639 splat_v
640}
641
/// Render the AMX set/clear instruction as an inline-asm call.
///
/// The asm body is three `nop`s followed by a raw `.word` whose value is
/// `0x201000 + (17 << 5) + imm5`; the call clobbers `memory`.
fn amx_set_inline_asm(imm5: u32) -> String {
    const AMX_BASE: u32 = 0x201000;
    const AMX_SET_OP: u32 = 17;
    let encoded = AMX_BASE + (AMX_SET_OP << 5) + imm5;
    let asm_body = format!("nop\\0Anop\\0Anop\\0A.word ({encoded})");
    format!(" tail call void asm sideeffect \"{asm_body}\", \"~{{memory}}\"()")
}
658
/// Render an AMX instruction that takes a general-purpose register operand.
///
/// `op` is bound to the `i` (immediate) constraint as `$0` and `gpr_name` to
/// the `r` (register) constraint as `$1`, passed as an `i64`; the `.word`
/// expression folds both into the instruction encoding. Clobbers `memory`.
fn amx_inline_asm(op: u32, gpr_name: &str) -> String {
    let encoding = ".word (0x201000+($0<<5)+0$1-((0$1>>4)*6))";
    let constraints = "i,r,~{memory}";
    format!(" tail call void asm sideeffect \"{encoding}\", \"{constraints}\"(i32 {op}, i64 {gpr_name})")
}
672
/// Same inline-asm template as the register form, but the `r` operand is the
/// integer constant `imm` (materialised into a register by the asm constraint)
/// rather than a named SSA value. Clobbers `memory`.
fn amx_inline_asm_imm(op: u32, imm: u64) -> String {
    let encoding = ".word (0x201000+($0<<5)+0$1-((0$1>>4)*6))";
    let constraints = "i,r,~{memory}";
    format!(" tail call void asm sideeffect \"{encoding}\", \"{constraints}\"(i32 {op}, i64 {imm})")
}
681
682fn binary_instr(op: BinaryOp, dtype: &DType) -> &'static str {
683 assert!(
684 !matches!(dtype.base(), svod_dtype::ScalarDType::Index),
685 "Index dtype reached LLVM codegen binary_instr({op:?}, {dtype:?}) — \
686 pm_lower_index_dtype should have lowered it to i32/i64"
687 );
688 let is_float = dtype.is_float();
689 let is_signed = dtype.is_signed();
690
691 match op {
692 BinaryOp::Add => {
693 if is_float {
694 "fadd nsz arcp contract afn"
695 } else if is_signed {
696 "add nsw"
697 } else {
698 "add"
699 }
700 }
701 BinaryOp::Mul => {
702 if is_float {
703 "fmul nsz arcp contract afn"
704 } else {
705 "mul"
706 }
707 }
708 BinaryOp::Sub => {
709 if is_float {
710 "fsub nsz arcp contract afn"
711 } else {
712 "sub"
713 }
714 }
715 BinaryOp::Fdiv => "fdiv nsz arcp contract afn",
716 BinaryOp::Idiv => {
717 if is_signed {
718 "sdiv"
719 } else {
720 "udiv"
721 }
722 }
723 BinaryOp::Mod => {
724 if is_float {
725 "frem nsz arcp contract afn"
726 } else if is_signed {
727 "srem"
728 } else {
729 "urem"
730 }
731 }
732 BinaryOp::Max => {
733 if is_float {
734 "maxnum"
735 } else if is_signed {
736 "smax"
737 } else {
738 "umax"
739 }
740 }
741 BinaryOp::Lt => {
742 if is_float {
743 "fcmp nsz arcp contract afn ult"
744 } else if is_signed {
745 "icmp slt"
746 } else {
747 "icmp ult"
748 }
749 }
750 BinaryOp::Le => {
751 if is_float {
752 "fcmp nsz arcp contract afn ule"
753 } else if is_signed {
754 "icmp sle"
755 } else {
756 "icmp ule"
757 }
758 }
759 BinaryOp::Gt => {
760 if is_float {
761 "fcmp nsz arcp contract afn ugt"
762 } else if is_signed {
763 "icmp sgt"
764 } else {
765 "icmp ugt"
766 }
767 }
768 BinaryOp::Ge => {
769 if is_float {
770 "fcmp nsz arcp contract afn uge"
771 } else if is_signed {
772 "icmp sge"
773 } else {
774 "icmp uge"
775 }
776 }
777 BinaryOp::Eq => {
778 if is_float {
779 "fcmp nsz arcp contract afn oeq"
780 } else {
781 "icmp eq"
782 }
783 }
784 BinaryOp::Ne => {
785 if is_float {
786 "fcmp nsz arcp contract afn une"
787 } else {
788 "icmp ne"
789 }
790 }
791 BinaryOp::And => "and",
792 BinaryOp::Or => "or",
793 BinaryOp::Xor => "xor",
794 BinaryOp::Shl => "shl",
795 BinaryOp::Shr => {
796 if is_signed {
797 "ashr"
798 } else {
799 "lshr"
800 }
801 }
802 BinaryOp::Pow => "pow",
803 BinaryOp::Threefry => "xor",
804 }
805}
806
807fn unary_instr(op: UnaryOp, dtype: &DType) -> Option<&'static str> {
808 let is_float = dtype.is_float();
809
810 match op {
811 UnaryOp::Neg => Some(if is_float { "fneg" } else { "sub" }),
812 UnaryOp::Not => Some("xor"),
813 UnaryOp::Sqrt => Some("sqrt"),
814 UnaryOp::Rsqrt => None,
815 UnaryOp::Exp => Some("exp"),
816 UnaryOp::Exp2 => Some("exp2"),
817 UnaryOp::Log => Some("log"),
818 UnaryOp::Log2 => Some("log2"),
819 UnaryOp::Sin => Some("sin"),
820 UnaryOp::Cos => Some("cos"),
821 UnaryOp::Abs => Some(if is_float { "fabs" } else { "abs" }),
822 UnaryOp::Floor => Some("floor"),
823 UnaryOp::Ceil => Some("ceil"),
824 UnaryOp::Trunc => Some("trunc"),
825 UnaryOp::Round => Some("rint"),
826 UnaryOp::Reciprocal => None,
827 UnaryOp::Tan => None,
828 UnaryOp::Sign => None,
829 UnaryOp::Erf => None,
830 UnaryOp::Square => None,
831 }
832}
833
834fn reduce_instr(op: ReduceOp, dtype: &DType) -> &'static str {
835 let is_float = dtype.is_float();
836 let is_signed = dtype.is_signed();
837
838 match op {
839 ReduceOp::Add => {
840 if is_float {
841 "fadd nsz arcp contract afn"
842 } else {
843 "add"
844 }
845 }
846 ReduceOp::Mul => {
847 if is_float {
848 "fmul nsz arcp contract afn"
849 } else {
850 "mul"
851 }
852 }
853 ReduceOp::Max => {
854 if is_float {
855 "maxnum"
856 } else if is_signed {
857 "smax"
858 } else {
859 "umax"
860 }
861 }
862 ReduceOp::Min => {
863 if is_float {
864 "minnum"
865 } else if is_signed {
866 "smin"
867 } else {
868 "umin"
869 }
870 }
871 }
872}
873
/// Convert an LLVM IR type name into the suffix LLVM uses when mangling
/// overloaded intrinsic names, e.g. `float` -> `f32`, `<4 x float>` -> `v4f32`.
///
/// Floating-point types map to `f16`/`bf16`/`f32`/`f64`/`f128`. Integer types
/// (`iN`) are already their own suffix and, like any unrecognised type, are
/// returned unchanged. Fixed-width vectors `<N x T>` become `vN` followed by
/// the mangled element type.
fn mangle_type(llvm_type: &str) -> String {
    if let Some(inner) = llvm_type.strip_prefix('<').and_then(|rest| rest.strip_suffix('>')) {
        let parts: Vec<&str> = inner.split(" x ").collect();
        return match parts.as_slice() {
            [count, elem] => format!("v{}{}", count.trim(), mangle_type(elem.trim())),
            // Not a simple `<N x T>`: leave it as-is.
            _ => llvm_type.to_string(),
        };
    }
    let suffix = match llvm_type {
        "half" => "f16",
        "bfloat" => "bf16",
        "float" => "f32",
        "double" => "f64",
        "fp128" => "f128",
        other => other,
    };
    suffix.to_string()
}
897
898fn render_intrinsic(dst: &str, name: &str, args: &[(&str, &str)], ret_type: &str, kernel: &mut Vec<String>) {
899 let args_str: String = args.iter().map(|(ty, val)| format!("{ty} {val}")).collect::<Vec<_>>().join(", ");
900 let mangled = mangle_type(ret_type);
901 kernel.push(format!(" {dst} = call {ret_type} @llvm.{name}.{mangled}({args_str})"));
902}
903
904fn render_binary_max(dst: &str, lhs: &Arc<UOp>, l: &str, r: &str, ltype: &str, kernel: &mut Vec<String>) {
905 if lhs.dtype().is_float() {
906 render_intrinsic(dst, "maxnum", &[(ltype, l), (ltype, r)], ltype, kernel);
907 } else {
908 let is_signed = lhs.dtype().is_signed();
909 let cmp = if is_signed { "sgt" } else { "ugt" };
910 let cmp_dst = format!("{dst}.cmp");
911 kernel.push(format!(" {cmp_dst} = icmp {cmp} {ltype} {l}, {r}"));
912 kernel.push(format!(" {dst} = select i1 {cmp_dst}, {ltype} {l}, {ltype} {r}"));
913 }
914}
915
916fn render_binary_pow(dst: &str, lhs: &Arc<UOp>, l: &str, r: &str, ltype: &str, kernel: &mut Vec<String>) {
917 if lhs.dtype().is_float() {
918 render_intrinsic(dst, "pow", &[(ltype, l), (ltype, r)], ltype, kernel);
919 } else {
920 let l_float = format!("{dst}.lf");
921 let r_float = format!("{dst}.rf");
922 let pow_float = format!("{dst}.pf");
923 kernel.push(format!(" {l_float} = sitofp {ltype} {l} to double"));
924 kernel.push(format!(" {r_float} = sitofp {ltype} {r} to double"));
925 render_intrinsic(&pow_float, "pow", &[("double", &l_float), ("double", &r_float)], "double", kernel);
926 kernel.push(format!(" {dst} = fptosi double {pow_float} to {ltype}"));
927 }
928}
929
930fn render_reduce_minmax(
931 dst: &str,
932 op: ReduceOp,
933 dtype: &DType,
934 acc: &str,
935 val: &str,
936 ltype: &str,
937 kernel: &mut Vec<String>,
938) {
939 if dtype.is_float() {
940 let intrinsic = match op {
941 ReduceOp::Max => "maxnum",
942 ReduceOp::Min => "minnum",
943 _ => unreachable!(),
944 };
945 render_intrinsic(dst, intrinsic, &[(ltype, acc), (ltype, val)], ltype, kernel);
946 } else {
947 let is_signed = dtype.is_signed();
948 let cmp = match op {
949 ReduceOp::Max => {
950 if is_signed {
951 "sgt"
952 } else {
953 "ugt"
954 }
955 }
956 ReduceOp::Min => {
957 if is_signed {
958 "slt"
959 } else {
960 "ult"
961 }
962 }
963 _ => unreachable!(),
964 };
965 let cmp_dst = format!("{dst}.cmp");
966 kernel.push(format!(" {cmp_dst} = icmp {cmp} {ltype} {acc}, {val}"));
967 kernel.push(format!(" {dst} = select i1 {cmp_dst}, {ltype} {acc}, {ltype} {val}"));
968 }
969}
970
971fn render_multi_gep(
972 dst: &str,
973 vec: &str,
974 vec_dtype: &DType,
975 indices: &[usize],
976 out_type: &str,
977 kernel: &mut Vec<String>,
978) {
979 let vec_type = ldt(vec_dtype);
980
981 let elem_dtype = match vec_dtype {
982 DType::Ptr { base, addrspace, size, .. } => {
983 DType::Ptr { base: base.clone(), addrspace: *addrspace, size: *size, vcount: 1 }
984 }
985 DType::Vector { scalar, .. } => DType::Scalar(*scalar),
986 _ => DType::Scalar(vec_dtype.base()),
987 };
988 let elem_type = ldt(&elem_dtype);
989
990 for (i, &idx) in indices.iter().enumerate() {
991 let elem = format!("{dst}.e{i}");
992 kernel.push(format!(" {elem} = extractelement {vec_type} {vec}, i32 {idx}"));
993 }
994
995 if indices.len() == 1 {
996 kernel.push(format!(" {dst} = bitcast {elem_type} {dst}.e0 to {out_type}"));
997 } else {
998 let count = indices.len();
999 kernel.push(format!(" {dst}.undef = undef <{count} x {elem_type}>"));
1000 let mut prev = format!("{dst}.undef");
1001 for i in 0..count {
1002 let next = if i == count - 1 { dst.to_string() } else { format!("{dst}.v{i}") };
1003 kernel.push(format!(
1004 " {next} = insertelement <{count} x {elem_type}> {prev}, {elem_type} {dst}.e{i}, i32 {i}"
1005 ));
1006 prev = next;
1007 }
1008 }
1009}
1010
1011fn render_vectorize(dst: &str, elements: &[Arc<UOp>], ctx: &RenderContext, kernel: &mut Vec<String>) {
1012 if elements.is_empty() {
1013 return;
1014 }
1015
1016 let scalar_type = ldt(&elements[0].dtype());
1017 let count = elements.len();
1018 let vec_type = format!("<{count} x {scalar_type}>");
1019
1020 let mut prev = "undef".to_string();
1021 for (i, elem) in elements.iter().enumerate() {
1022 let val = ctx.get(elem);
1023 let next = if i == count - 1 { dst.to_string() } else { format!("{dst}.v{i}") };
1024 kernel.push(format!(" {next} = insertelement {vec_type} {prev}, {scalar_type} {val}, i32 {i}"));
1025 prev = next;
1026 }
1027}
1028
1029fn render_cat(dst: &str, sources: &[Arc<UOp>], ctx: &RenderContext, kernel: &mut Vec<String>) {
1030 if sources.is_empty() {
1031 return;
1032 }
1033
1034 let total_count: usize = sources.iter().map(|s| s.dtype().vcount()).sum();
1035 let scalar_type = ldt(&sources[0].dtype().scalar_dtype());
1036 let out_type = format!("<{total_count} x {scalar_type}>");
1037
1038 let mut out_idx = 0;
1039 let mut prev = "undef".to_string();
1040
1041 for src in sources.iter() {
1042 let src_val = ctx.get(src);
1043 let src_count = src.dtype().vcount();
1044
1045 if src_count == 1 {
1046 let next = if out_idx == total_count - 1 { dst.to_string() } else { format!("{dst}.c{out_idx}") };
1047 kernel.push(format!(" {next} = insertelement {out_type} {prev}, {scalar_type} {src_val}, i32 {out_idx}"));
1048 prev = next;
1049 out_idx += 1;
1050 } else {
1051 let src_type = ldt(&src.dtype());
1052 for i in 0..src_count {
1053 let elem = format!("{dst}.e{out_idx}");
1054 kernel.push(format!(" {elem} = extractelement {src_type} {src_val}, i32 {i}"));
1055
1056 let next = if out_idx == total_count - 1 { dst.to_string() } else { format!("{dst}.c{out_idx}") };
1057 kernel.push(format!(" {next} = insertelement {out_type} {prev}, {scalar_type} {elem}, i32 {out_idx}"));
1058 prev = next;
1059 out_idx += 1;
1060 }
1061 }
1062 }
1063}
1064
1065pub fn reduce_identity(op: ReduceOp, dtype: &DType) -> String {
1067 let is_vector = matches!(dtype, DType::Vector { .. });
1068
1069 match op {
1070 ReduceOp::Add => {
1071 if is_vector {
1072 "zeroinitializer".to_string()
1073 } else if dtype.is_float() {
1074 "0.0".to_string()
1075 } else {
1076 "0".to_string()
1077 }
1078 }
1079 ReduceOp::Mul => {
1080 if is_vector {
1081 "zeroinitializer".to_string()
1082 } else if dtype.is_float() {
1083 "1.0".to_string()
1084 } else {
1085 "1".to_string()
1086 }
1087 }
1088 ReduceOp::Max => {
1089 if is_vector {
1090 "zeroinitializer".to_string()
1091 } else if dtype.is_float() {
1092 "-0x7FF0000000000000".to_string()
1093 } else if dtype.is_signed() {
1094 i64::MIN.to_string()
1095 } else {
1096 "0".to_string()
1097 }
1098 }
1099 ReduceOp::Min => {
1100 if is_vector {
1101 "zeroinitializer".to_string() } else if dtype.is_float() {
1103 "0x7FF0000000000000".to_string() } else if dtype.is_signed() {
1105 i64::MAX.to_string()
1106 } else {
1107 u64::MAX.to_string()
1108 }
1109 }
1110 }
1111}